From 68563dc0195084b56e45e6ee63b1416e7616377b Mon Sep 17 00:00:00 2001
From: Thomas Viehmann
Date: Tue, 7 Jan 2025 21:37:41 +0100
Subject: [PATCH 1/4] refactor recomputation to work with tags

---
 thunder/core/trace_interpreter.py   |   5 +-
 thunder/core/transforms.py          | 132 +++++++++++++++++++---
 thunder/executors/torch_autograd.py |   7 +-
 thunder/tests/test_grad.py          |  29 ++++++
 4 files changed, 129 insertions(+), 44 deletions(-)

diff --git a/thunder/core/trace_interpreter.py b/thunder/core/trace_interpreter.py
index b38cd01ac7..24e196bb0a 100644
--- a/thunder/core/trace_interpreter.py
+++ b/thunder/core/trace_interpreter.py
@@ -6,7 +6,7 @@
 from thunder.core.trace import VariableInterface, from_trace, tracectx
 from thunder.core.baseutils import ProxyInterface, TensorProxyInterface
 from thunder.core.utils import safe_map_flat, sequencify
-from thunder.core.proxies import variableify
+from thunder.core.proxies import variableify, ProxyTag
 from thunder.core.transform_common import VJPDual
 
 
@@ -183,6 +183,9 @@ def do_swap(v):
 
     for new_bsym in new_bsyms:
         # TODO: what to do with bsym header? Maybe have a combined from_bsym_swap_proxies and from_bsym?
+        for o in new_bsym.flat_proxy_outs:
+            if variableify(o) not in swap_map:
+                o.tags.add(ProxyTag.RECOMPUTE_IN_BACKWARD)
         new_trace.bound_symbols.append(
             new_bsym.from_bsym_swap_proxies(swap_map).from_bsym(
                 source_filename=bsym.source_filename, source_positions=bsym.source_positions
diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py
index e28821a8b5..09c7ba1509 100644
--- a/thunder/core/transforms.py
+++ b/thunder/core/transforms.py
@@ -38,7 +38,7 @@
 from thunder.core.compile_data import get_compile_data, get_compile_option
 from thunder.core.langctxs import langctx, Languages
 from thunder.core.pytree import tree_flatten, tree_map, tree_unflatten, tree_flatten_with_dataclass
-from thunder.core.symbol import BoundSymbol, BoundSymbolInterface, Symbol
+from thunder.core.symbol import BoundSymbol, BoundSymbolInterface, Symbol, has_tags
 from thunder.core.trace import TraceCtx as Trace
 from thunder.core.trace import VariableInterface as Variable
 from thunder.core.trace import (
@@ -59,6 +59,7 @@
     unzip2,
     const_as,
     sequencify,
+    OrderedSet,
     ProxyDict,
     find_producer_symbols,
 )
@@ -92,6 +93,7 @@
 
 TraceTag.register_tag("AUGMENTED_FORWARD")
+ProxyTag.register_tag("RECOMPUTE_IN_BACKWARD")
 
 
 # TODO This should be a partial of thunder.trace, but that would cause a circular import
@@ -2991,6 +2993,7 @@ def unpacking_fn(saved_for_backward, cotangents):
         )
     )
     backward_trace.bound_symbols = list((*unpacking_trace.bound_symbols[:-1], *backward_trace_bsyms_without_unpacking))
+    backward_trace.scopes[0] = backward_trace.bound_symbols
 
 
 def forward_and_backward_from_trace(trace: Trace, torch_autograd=False) -> ForwardBackwardTraces:
@@ -3153,6 +3156,8 @@ def backward_fn(saved_for_backward, cotangents):
     enable_saved_for_backward_recomputation: None | bool = get_compile_option(
         "enable_saved_for_backward_recomputation", "Enable save for backward tensors recomputation."
     )
+    if enable_saved_for_backward_recomputation is None:
+        enable_saved_for_backward_recomputation = True
     if enable_saved_for_backward_recomputation:
         forward_trace, backward_trace = recompute_saved_for_backward(forward_trace, backward_trace)
 
@@ -3171,11 +3176,32 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Tr
     start_time_ns = time.perf_counter_ns()
 
+    cd = get_compile_data()
+    have_nvfuser = any([ex.name == "nvfuser" for ex in cd.executors_list]) if cd is not None else False
+    if have_nvfuser:
+        from thunder.core.rematerialization import replace_uniform
+
+        fwd_trace = replace_uniform(fwd_trace)
+
     saved_for_bw = get_saved_for_backward_tensors(fwd_trace)
-    fwd_trace_args = {variableify(j) for j in fwd_trace.args}
-    old_saved_for_bwd = {variableify(j) for j in saved_for_bw}
+    fwd_trace_args = OrderedSet(variableify(j) for j in fwd_trace.args)
+    old_saved_for_bwd = OrderedSet(variableify(j) for j in saved_for_bw)
+
+    all_proxies = fwd_trace_args.copy()
+    all_recomputable_proxies = OrderedSet()
+
+    proxy_names_to_producers = {}
+    for bsym in fwd_trace.bound_symbols:
+        for o in bsym.flat_proxy_outs:
+            vo = variableify(o)
+
+            if vo in all_proxies:
+                continue
 
-    all_rematerializable = old_saved_for_bwd - fwd_trace_args
+            proxy_names_to_producers[o.name] = bsym
+            all_proxies.add(vo)
+            if ProxyTag.RECOMPUTE_IN_BACKWARD in o.tags and not has_tags(bsym, {prims.OpTags.RANDOM_OP}):
+                all_recomputable_proxies.add(vo)
 
     remat_policy: None | Callable[[set[Variable]], set[Variable]] = get_compile_option(
         "recomputation_policy",
@@ -3183,59 +3209,81 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Tr
     )
 
     if remat_policy:
-        rematerializable = remat_policy(all_rematerializable)
+        rematerializable = remat_policy(old_saved_for_bwd - fwd_trace_args)
     else:
-        rematerializable = all_rematerializable
-
-    producers = find_producer_symbols(fwd_trace, tuple(unvariableify(i) for i in rematerializable), fwd_trace.args)
-
-    required_fw_args = fwd_trace_args & old_saved_for_bwd
-    recomputed_tensors_from_producers = set()
-    for prod in producers:
-        for prod_arg in prod.flat_args:
-            prod_arg = variableify(prod_arg)
-            if prod_arg in fwd_trace_args:
-                required_fw_args.add(prod_arg)
-        for prod_out in prod.flat_outs:
-            recomputed_tensors_from_producers.add(variableify(prod_out))
-
-    required_saved_for_bwd = all_rematerializable - rematerializable - recomputed_tensors_from_producers
-    new_saved_for_backward = tuple(unvariableify(i) for i in required_fw_args | required_saved_for_bwd)
-
-    new_fwd_trace = from_trace(fwd_trace)
-    new_fwd_trace.bound_symbols = fwd_trace.bound_symbols.copy()
-    new_return_args = (fwd_trace.output[0], (new_saved_for_backward, fwd_trace.output[1][1]))
-    new_fwd_trace.bound_symbols[-1] = prims.python_return.bind(*new_return_args, output=None)
+        rematerializable = old_saved_for_bwd & all_recomputable_proxies
 
-    new_bwd_trace = from_trace(bwd_trace)
-    # In cases where C0 name is carried from previous trace it must be removed
-    # as the proxy needs to register with that specific name to follow the backward
-    # trace standard signature.
-    new_bwd_trace.names.discard("C0")
-
-    with tracectx(new_bwd_trace):
-        unpack_args = (CollectionProxy(new_saved_for_backward, name="C0"), len(new_saved_for_backward))
+    if not rematerializable:
+        return fwd_trace, bwd_trace
 
     # Here we make sure that the signature of the backward trace is the same as the one we expect.
     # This part of the trace is the unpacking of the tuple passed from the forward trace,
-    # more specifically, C0 unpacks into the saved for backward tensors and C1 into the cotangents
-    # used to compute the vector-Jacobian product.
+    # more specifically, C0 unpacks into the saved for backward tensors and C1 into the nontensors;
+    # the cotangents used to compute the vector-Jacobian product are a separate argument.
+    assert bwd_trace.bound_symbols[2].sym.id == prims.PrimIDs.UNPACK_SEQUENCE
+    assert bwd_trace.bound_symbols[2].args[0].name == "saved_for_backward"
     assert bwd_trace.bound_symbols[4].sym.id == prims.PrimIDs.UNPACK_SEQUENCE
     assert bwd_trace.bound_symbols[4].args[0].name == "C0"
     assert bwd_trace.bound_symbols[5].sym.id == prims.PrimIDs.UNPACK_SEQUENCE
     assert bwd_trace.bound_symbols[5].args[0].name == "C1"
+    p_saved_for_backward = bwd_trace.bound_symbols[2].args[0]
+    p_c0 = bwd_trace.bound_symbols[4].args[0]
+    p_c1 = bwd_trace.bound_symbols[5].args[0]
+
+    saved_tensors = fwd_trace_args & old_saved_for_bwd
+    saved_nontensors = OrderedSet(variableify(p) for p in p_c1.coll)
+
+    new_bwd_trace = from_trace(bwd_trace)
+
+    # args will be added from unpack_trivial
+    have_in_backward = saved_tensors | saved_nontensors
+
+    def compute_proxy_from_producer(p):
+        vp = variableify(p)
+        if vp in have_in_backward:
+            return
+        if vp not in all_recomputable_proxies:
+            have_in_backward.add(vp)
+            if isinstance(p, TensorProxy):
+                saved_tensors.add(vp)
+            else:
+                saved_nontensors.add(vp)
+            return
+        producer_bsym = proxy_names_to_producers[p.name]
+        for p in producer_bsym.flat_proxy_args:
+            compute_proxy_from_producer(p)
+        for o in producer_bsym.flat_proxy_outs:
+            have_in_backward.add(variableify(o))
+        new_bwd_trace.bound_symbols.append(producer_bsym)
 
     for idx, bsym in enumerate(bwd_trace.bound_symbols):
-        if idx == 4:
-            new_unpack = prims.unpack_sequence.bind(*unpack_args, output=new_saved_for_backward)
-            new_bwd_trace.bound_symbols.append(new_unpack)
-        elif idx == 6:
-            new_bwd_trace.bound_symbols.extend(producers)
+        if idx in {4, 5}:
+            # handled later
             new_bwd_trace.bound_symbols.append(bsym)
         else:
+            for p in bsym.flat_proxy_args:
+                compute_proxy_from_producer(p)
+            for o in bsym.flat_proxy_outs:
+                have_in_backward.add(variableify(o))
             new_bwd_trace.bound_symbols.append(bsym)
 
-    new_bwd_trace.args = [(new_saved_for_backward, fwd_trace.output[1][1]), *bwd_trace.args[1:]]
+    new_fwd_trace = from_trace(fwd_trace)
+    new_fwd_trace.bound_symbols = fwd_trace.bound_symbols.copy()
+
+    # TODO: fix ordering...
+
+    new_c0 = tuple(unvariableify(i) for i in saved_tensors)
+    new_c1 = tuple(unvariableify(i) for i in saved_nontensors)
+
+    new_return_args = (fwd_trace.output[0], (new_c0, new_c1))
+    new_fwd_trace.bound_symbols[-1] = prims.python_return.bind(*new_return_args, output=None)
+
+    p_saved_for_backward.coll = (new_c0, new_c1)
+    p_c0.coll = new_c0
+    p_c1.coll = new_c1
+
+    new_bwd_trace.bound_symbols[4] = prims.unpack_sequence.bind(p_c0, len(new_c0), output=new_c0)
+    new_bwd_trace.bound_symbols[5] = prims.unpack_sequence.bind(p_c1, len(new_c1), output=new_c1)
+    new_bwd_trace.args = [(new_c0, new_c1), *bwd_trace.args[1:]]
 
     elapsed_time_ns = time.perf_counter_ns() - start_time_ns
     new_bwd_trace.set_provenance(
diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py
index 454582c92e..0c20f4866f 100644
--- a/thunder/executors/torch_autograd.py
+++ b/thunder/executors/torch_autograd.py
@@ -4,6 +4,7 @@
 import torch
 
 import thunder.core.utils as utils
+from thunder.core.compile_data import get_compile_option
 from thunder.core.prims import PrimIDs
 from thunder.core.proxies import TensorProxy, variableify
 from thunder.core.pytree import tree_flatten
@@ -349,7 +350,11 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
     )
     bw_traces.append(bw_extrace)
 
-    fw_extrace, bw_extrace = rematerialize_forward_and_backward(fw_extrace, bw_extrace)
+    use_rematerialization: None | bool = get_compile_option(
+        "use_rematerialization", "use rematerialization of parameters"
+    )
+    if use_rematerialization:
+        fw_extrace, bw_extrace = rematerialize_forward_and_backward(fw_extrace, bw_extrace)
     fw_traces.append(fw_extrace)
     bw_traces.append(bw_extrace)
 
diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py
index 73e3453e9a..f8744415d6 100644
--- a/thunder/tests/test_grad.py
+++ b/thunder/tests/test_grad.py
@@ -27,6 +27,7 @@
     run_snippet,
     assert_closer,
     IN_CI,
+    NVFUSER_AVAILABLE,
     requiresCUDA,
     version_between,
 )
@@ -1886,3 +1887,31 @@ def func(x):
 
     torch.testing.assert_close(actual, expected)
     torch.testing.assert_close(actual_gr, expected_gr)
+
+
+@pytest.mark.parametrize("device", ("cuda", "cpu"))
+def test_backward_recomputation_decomposed_ops(device):
+    if device == "cuda" and not torch.cuda.is_available():
+        pytest.skip("CUDA is not available")
+
+    def fn(a):
+        return torch.nn.functional.gelu(a)
+
+    jfn = thunder.jit(fn, enable_saved_for_backward_recomputation=False)
+    jfn2 = thunder.jit(fn, enable_saved_for_backward_recomputation=True)
+    a = torch.randn(2, 2, device=device, requires_grad=True)
+    res = jfn(a)
+    res2 = jfn2(a)
+    assert len(res.grad_fn.next_functions[0][0].saved_tensors) == 3  # should be decomposed
+    assert len(res2.grad_fn.next_functions[0][0].saved_tensors) == 1
+
+    if NVFUSER_AVAILABLE and device == "cuda":
+        # check everything is fused
+        assert {bsym.sym.name for bsym in thunder.last_backward_traces(jfn2)[-1].bound_symbols} == {
+            "nvFusion0",
+            "clear_mutable_collection",
+            "python_return",
+            "python_del",
+            "unpack_sequence",
+            "unpack_trivial",
+        }

From e9060093517b5e1841842d9cd4626b3806a25c10 Mon Sep 17 00:00:00 2001
From: Thomas Viehmann
Date: Tue, 7 Jan 2025 22:43:18 +0100
Subject: [PATCH 2/4] fix the function-passing and update tests

---
 thunder/core/transforms.py           | 32 +++++++++++++++++++---------
 thunder/tests/test_examine_memory.py |  8 +++----
 thunder/tests/test_transforms.py     |  3 ++-
 3 files changed, 28 insertions(+), 15 deletions(-)

diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py
index 09c7ba1509..370dab3f09 100644
--- a/thunder/core/transforms.py
+++ b/thunder/core/transforms.py
@@ -3153,10 +3153,14 @@ def backward_fn(saved_for_backward, cotangents):
     forward_trace.set_provenance(TraceProvenance("Augmented forward pass"))
     backward_trace.set_provenance(TraceProvenance("Backward pass"))
 
+    remat_policy: None | Callable[[set[Variable]], set[Variable]] = get_compile_option(
+        "recomputation_policy",
+        "A callable that accepts a set of variables and returns a set of the variables that are allowed to be recomputed from the forward in the backward trace. Providing a policy implies `enable_saved_for_backward_recomputation=True`.",
+    )
     enable_saved_for_backward_recomputation: None | bool = get_compile_option(
         "enable_saved_for_backward_recomputation", "Enable save for backward tensors recomputation."
     )
-    if enable_saved_for_backward_recomputation is None:
+    if enable_saved_for_backward_recomputation is None or remat_policy:
         enable_saved_for_backward_recomputation = True
     if enable_saved_for_backward_recomputation:
         forward_trace, backward_trace = recompute_saved_for_backward(forward_trace, backward_trace)
@@ -3187,6 +3191,22 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Tr
     fwd_trace_args = OrderedSet(variableify(j) for j in fwd_trace.args)
     old_saved_for_bwd = OrderedSet(variableify(j) for j in saved_for_bw)
 
+    remat_policy: None | Callable[[set[Variable]], set[Variable]] = get_compile_option(
+        "recomputation_policy",
+        "A callable that accepts a set of variables and returns a set of the variables that are allowed to be recomputed from the forward in the backward trace. Providing a policy implies `enable_saved_for_backward_recomputation=True`.",
+    )
+
+    enable_saved_for_backward_recomputation: None | bool = get_compile_option(
+        "enable_saved_for_backward_recomputation", "Enable save for backward tensors recomputation."
+    )
+
+    if remat_policy:
+        for v in remat_policy(old_saved_for_bwd - fwd_trace_args):
+            unvariableify(v).tags.add(ProxyTag.RECOMPUTE_IN_BACKWARD)
+    elif enable_saved_for_backward_recomputation:
+        for v in old_saved_for_bwd - fwd_trace_args:
+            unvariableify(v).tags.add(ProxyTag.RECOMPUTE_IN_BACKWARD)
+
     all_proxies = fwd_trace_args.copy()
     all_recomputable_proxies = OrderedSet()
 
@@ -3203,15 +3223,7 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Tr
             if ProxyTag.RECOMPUTE_IN_BACKWARD in o.tags and not has_tags(bsym, {prims.OpTags.RANDOM_OP}):
                 all_recomputable_proxies.add(vo)
 
-    remat_policy: None | Callable[[set[Variable]], set[Variable]] = get_compile_option(
-        "recomputation_policy",
-        "A callable that accepts a set of variables and returns a set of the variables that are allowed to be recomputed from the forward in the backward trace. The compile option `enable_saved_for_backward_recomputation` needs to be true for this policy to take effect.",
-    )
-
-    if remat_policy:
-        rematerializable = remat_policy(old_saved_for_bwd - fwd_trace_args)
-    else:
-        rematerializable = old_saved_for_bwd & all_recomputable_proxies
+    rematerializable = old_saved_for_bwd & all_recomputable_proxies
 
     if not rematerializable:
         return fwd_trace, bwd_trace
diff --git a/thunder/tests/test_examine_memory.py b/thunder/tests/test_examine_memory.py
index 51278003de..80c0b5b83c 100644
--- a/thunder/tests/test_examine_memory.py
+++ b/thunder/tests/test_examine_memory.py
@@ -113,7 +113,7 @@ def test_nanogpt_block():
 
     # Actual memory usage may vary depending on hardware and cuBLAS settings.
     # We are checking the estimated memory against a fixed value for consistency.
-    assert max_mem_fw[0] == 381754368
-    assert sum(max_mem_fw[1].values()) == 375462912
-    assert max_mem_bw[0] == 437292032
-    assert sum(max_mem_bw[1].values()) == 40934400
+    assert max_mem_fw[0] == 262183936
+    assert sum(max_mem_fw[1].values()) == 135306240
+    assert max_mem_bw[0] == 484833280
+    assert sum(max_mem_bw[1].values()) == 169915392
diff --git a/thunder/tests/test_transforms.py b/thunder/tests/test_transforms.py
index 2418eeae6a..b10e0048d2 100644
--- a/thunder/tests/test_transforms.py
+++ b/thunder/tests/test_transforms.py
@@ -381,7 +381,7 @@ def forward(self, x):
 
     fwd_trace = None
     for trace in thunder.last_traces(jmodel):
-        if str(trace.get_provenance()) == "# Constructed by Augmented forward pass":
+        if thunder.core.trace.TraceTag.AUGMENTED_FORWARD in trace.tags:
            fwd_trace = trace
            break
 
@@ -433,6 +433,7 @@ def forward(self, x):
             filter(lambda x: isinstance(x.output, TensorProxy), new_bwd.bound_symbols[6:]),
         )
     )
 
+    # check that all the rematerializable forward values are recomputed in the backward
     for rematerializable in all_rematerializable:
         assert rematerializable in bwd_bsym_out

From 2979681dcd4b94327c5105c699be421bc78bd0f6 Mon Sep 17 00:00:00 2001
From: Thomas Viehmann
Date: Tue, 7 Jan 2025 23:32:56 +0100
Subject: [PATCH 3/4] fix fsdp test

---
 thunder/tests/distributed/test_fsdp.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/thunder/tests/distributed/test_fsdp.py b/thunder/tests/distributed/test_fsdp.py
index a5716e8c9b..335ebefca8 100644
--- a/thunder/tests/distributed/test_fsdp.py
+++ b/thunder/tests/distributed/test_fsdp.py
@@ -116,9 +116,7 @@ def test_rematerialize_all_gather(self):
         x = torch.ones((2, 12), device=device)
         cm(x).mean().backward()
 
-        fwd_trc = [
-            t for t in thunder.last_traces(cm) if getattr(t.get_provenance(), "pss", "") == "Augmented forward pass"
-        ][0]
+        fwd_trc = [t for t in thunder.last_traces(cm) if thunder.core.trace.TraceTag.AUGMENTED_FORWARD in t.tags][0]
         bwd_trc = thunder.last_backward_traces(cm)[0]
 
         from thunder.core.rematerialization import rematerialize_all_gather

From 0929eb4227c32439913cd02ed9e8342bb62e73f3 Mon Sep 17 00:00:00 2001
From: Thomas Viehmann
Date: Wed, 8 Jan 2025 15:42:29 +0100
Subject: [PATCH 4/4] review comments. Thank you Riccardo

---
 thunder/core/trace_interpreter.py   | 3 +++
 thunder/core/transforms.py          | 1 -
 thunder/executors/torch_autograd.py | 2 +-
 3 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/thunder/core/trace_interpreter.py b/thunder/core/trace_interpreter.py
index 24e196bb0a..6bee20a165 100644
--- a/thunder/core/trace_interpreter.py
+++ b/thunder/core/trace_interpreter.py
@@ -185,6 +185,9 @@ def do_swap(v):
         # TODO: what to do with bsym header? Maybe have a combined from_bsym_swap_proxies and from_bsym?
         for o in new_bsym.flat_proxy_outs:
             if variableify(o) not in swap_map:
+                # when we decompose to compute the forward/backward, we mark intermediates to be recomputed in the backward.
+                # Typically our decompositions are for things that will then be fused together.
+                # We could refine this heuristic to exclude "expensive" operations.
                 o.tags.add(ProxyTag.RECOMPUTE_IN_BACKWARD)
         new_trace.bound_symbols.append(
             new_bsym.from_bsym_swap_proxies(swap_map).from_bsym(
diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py
index 370dab3f09..1ecd0d0cc3 100644
--- a/thunder/core/transforms.py
+++ b/thunder/core/transforms.py
@@ -3282,7 +3282,6 @@ def compute_proxy_from_producer(p):
     new_fwd_trace = from_trace(fwd_trace)
     new_fwd_trace.bound_symbols = fwd_trace.bound_symbols.copy()
 
-    # TODO: fix ordering...
 
     new_c0 = tuple(unvariableify(i) for i in saved_tensors)
     new_c1 = tuple(unvariableify(i) for i in saved_nontensors)
diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py
index 0c20f4866f..f44732a1a1 100644
--- a/thunder/executors/torch_autograd.py
+++ b/thunder/executors/torch_autograd.py
@@ -351,7 +351,7 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
     bw_traces.append(bw_extrace)
 
     use_rematerialization: None | bool = get_compile_option(
-        "use_rematerialization", "use rematerialization of parameters"
+        "use_forward_backward_rematerialization", "use rematerialization of saved for backward values in fusions"
     )
     if use_rematerialization:
        fw_extrace, bw_extrace = rematerialize_forward_and_backward(fw_extrace, bw_extrace)
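
A minimal usage sketch of the compile options this series touches (not part of
the patches above; `fn` and `allow_all` are illustrative stand-ins, and the
option names follow the descriptions and the new test in patch 1):

    import torch
    import thunder

    def fn(a):
        return torch.nn.functional.gelu(a)

    # Recompute decomposed intermediates in the backward instead of saving
    # them; after patch 2 this is the default when the option is left unset.
    jfn = thunder.jit(fn, enable_saved_for_backward_recomputation=True)

    # Optionally restrict recomputation: the policy receives the set of
    # saved-for-backward variables that are not forward inputs and returns
    # the subset that may be recomputed. Providing a policy implies
    # enable_saved_for_backward_recomputation=True.
    def allow_all(variables):
        # a real policy would filter here, e.g. by proxy dtype or shape
        return variables

    jfn_policy = thunder.jit(fn, recomputation_policy=allow_all)

    a = torch.randn(2, 2, requires_grad=True)
    jfn(a).sum().backward()
    jfn_policy(a).sum().backward()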