diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py index 13e6305b85..50421188a8 100644 --- a/thunder/benchmarks/benchmark_litgpt.py +++ b/thunder/benchmarks/benchmark_litgpt.py @@ -541,19 +541,15 @@ def setup_distributed(self, model): return model def setup_activation_checkpointing(self): - if "thunder" in self.compile and "dynamo" not in self.compile: - # checkpointing is an option to thunder.jit - return - if any(isinstance(mod, CheckpointWrapper) for mod in self.model.modules()): warnings.warn( "FSDP checkpointing is configured, but the model already contains checkpointed layers." " Checkpointing will be ignored." ) return - check_fn = lambda submodule: isinstance(submodule, Block) apply_activation_checkpointing(self.model, checkpoint_wrapper_fn=checkpoint_wrapper, check_fn=check_fn) + print(self.model) # TODO(crcrpar): Think of apply `torch.compile` or `thunder.jit` per block/module # like https://github.com/pytorch/torchtitan/blob/cfc0f4e/torchtitan/parallelisms/parallelize_llama.py#L275-L284 @@ -890,12 +886,19 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None print(f"##########\n#Graph{gid}-ThunderFn{subgid} last backward trace\n##########") print(thunder.last_backward_traces(thunder_fn)[-1]) else: + from thunder.examine.memory_calculation import get_alloc_memory + from thunder.executors.passes import del_last_used + for i, f_traces in enumerate(fwd_traces, start=1): print(f"##########\n#{i}-th ThunderModule\n##########") - print(f_traces[-1]) + print(f_traces) for i, b_traces in enumerate(bwd_traces, start=1): print(f"##########\n#{i}-th ThunderModule\n##########") - print(b_traces[-1]) + for tr in b_traces: + dltr = del_last_used(tr) + tr_peak_memory, _ = get_alloc_memory(dltr) + print(f"#the following trace uses ~{tr_peak_memory/(2**30):.2f}GB memory") + print(tr) if global_rank in [0, None]: if return_metrics_as_json: diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index 726353983a..be478a778c 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -895,11 +895,18 @@ def _general_jit_torch_finfo_lookaside(dtype: thunder.dtypes.dtype): return res +ProxyTag.register_tag("RECOMPUTE_IN_BACKWARD") + + @register_general_jit_lookaside(torch.utils.checkpoint.checkpoint) def _general_jit_torch_checkpoint_lookaside( function: Callable, *args, - **kwargs: Any, + context_fn: None | Callable[..., Any] = None, + debug: None | bool = None, + determinism_check: None | str = None, + preserve_rng_state: None | bool = None, + use_reentrant: bool = False, ): """ This function does preprocessing of the `function` argument before @@ -917,17 +924,48 @@ def _general_jit_torch_checkpoint_lookaside( The result of calling `thunder.torch.checkpoint` with the preprocessed `function` and its arguments. 
""" - from thunder.torch import checkpoint - # It should be possible to call the general_thunder_jit here to handle the - # conversion from torch to thunder but it doesn't work now - # See https://github.com/Lightning-AI/lightning-thunder/issues/1126 - # TODO: Convert the function to a Thunder function - def thunder_function(*args, **kwargs): - return unwrap(function)(*args, **kwargs) + if unwrap(use_reentrant): + return do_raise( + "torch.checkpoint: use_reentrant=True is not supported in Thunder", + ) + # NOTE: Thunder currently ignores the context_fn, debug, determinism_check, preserve_rng_state arguments + # Let's raise a warning if any of these arguments are passed + if unwrap(context_fn) is not None: + warnings.warn("torch.checkpoint: context_fn is not supported in Thunder and will be ignored") + if unwrap(debug) is not None: + warnings.warn("torch.checkpoint: debug is not supported in Thunder and will be ignored") + if unwrap(determinism_check) is not None: + warnings.warn("torch.checkpoint: determinism_check is not supported in Thunder and will be ignored") + if unwrap(preserve_rng_state) is not None: + warnings.warn("torch.checkpoint: preserve_rng_state is not supported in Thunder and will be ignored") + + jit_ctx: JitCtx = get_jit_ctx() + jit_ctx.computation_trace.push_scope([]) + + input_output_proxy_names = set() + + def add_input_output_proxy_name(p): + if isinstance(p, Proxy): + input_output_proxy_names.add(p.name) + + tree_map(add_input_output_proxy_name, [unwrap(a) for a in args]) - wrapped_thunder_function = wrap_const(thunder_function) - return interpreter_needs_wrap(checkpoint)(wrapped_thunder_function, *args, **kwargs) + res = _interpret_call(function, *args) + if res is INTERPRETER_SIGNALS.EXCEPTION_RAISED: + return res + + tree_map(add_input_output_proxy_name, unwrap(res)) + + new_bsyms = jit_ctx.computation_trace.pop_scope() + jit_ctx.computation_trace.bound_symbols.extend(new_bsyms) + + for bsym in new_bsyms: + for o in bsym.flat_proxy_outs: + if o.name not in input_output_proxy_names: + o.tags.add(ProxyTag.RECOMPUTE_IN_BACKWARD) + + return res # Adds proxy methods diff --git a/thunder/core/rematerialization.py b/thunder/core/rematerialization.py index f24ee1fad2..1a5241f237 100644 --- a/thunder/core/rematerialization.py +++ b/thunder/core/rematerialization.py @@ -192,7 +192,9 @@ def apply_rematerialization_for_consumer( _, leaves = bsym_list_to_dag(list(new_subsymbols)) new_subsymbols = toposort_bsym_dag(leaves, TOPOSORT_ORDER.BOTTOM_UP) proxy_order = order_proxies(new_subsymbols) - new_consumer_args = tuple(sorted(new_consumer_args, key=lambda x: proxy_order[x.name])) + new_consumer_args = tuple( + sorted((a for a in new_consumer_args if a.name in proxy_order), key=lambda x: proxy_order[x.name]) + ) new_consumer = replace(consumer, args=new_consumer_args, subsymbols=new_subsymbols) return new_consumer diff --git a/thunder/core/symbol.py b/thunder/core/symbol.py index 2906241e11..74c488a479 100644 --- a/thunder/core/symbol.py +++ b/thunder/core/symbol.py @@ -16,7 +16,7 @@ import thunder.core.baseutils as baseutils import thunder.core.codeutils as codeutils from thunder.core.codeutils import Printable, Positions -from thunder.core.baseutils import BoundSymbolInterface, ProxyInterface +from thunder.core.baseutils import BoundSymbolInterface, ProxyInterface, TagBase from thunder.core.utils import FrozenDict, make_hashable from thunder.core.pytree import tree_flatten_with_dataclass, tree_unflatten, tree_map import thunder.core.dtypes as dtypes @@ -351,6 +351,10 @@ def 
tag_tensorproxy_output_as_detached(proxy): return result +class BoundSymbolTag(TagBase): + pass + + # A symbol, arguments (and kwarguments), output, and sub-symbols # args is a sequence of the arguments # kwargs is a dict of the kwargs @@ -377,6 +381,8 @@ class BoundSymbol(BoundSymbolInterface): source_filename: str | None = None source_positions: Positions | None = None + bsym_tags: set[BoundSymbolTag] = field(default_factory=set) + _call_ctx: None | dict[str, Any] = None _import_ctx: dict = field(default_factory=dict) @@ -412,6 +418,7 @@ def from_bsym(self, **kwargs) -> BoundSymbol: "_import_ctx": self._import_ctx, "_object_ctx": self._object_ctx, "_executor": self._executor, + "bsym_tags": self.bsym_tags.copy(), } self_kwargs.update(kwargs) diff --git a/thunder/core/trace_interpreter.py b/thunder/core/trace_interpreter.py index b38cd01ac7..24e196bb0a 100644 --- a/thunder/core/trace_interpreter.py +++ b/thunder/core/trace_interpreter.py @@ -6,7 +6,7 @@ from thunder.core.trace import VariableInterface, from_trace, tracectx from thunder.core.baseutils import ProxyInterface, TensorProxyInterface from thunder.core.utils import safe_map_flat, sequencify -from thunder.core.proxies import variableify +from thunder.core.proxies import variableify, ProxyTag from thunder.core.transform_common import VJPDual @@ -183,6 +183,9 @@ def do_swap(v): for new_bsym in new_bsyms: # TODO: what to do with bsym header? Maybe have a combined from_bsym_swap_proxies and from_bsym? + for o in new_bsym.flat_proxy_outs: + if variableify(o) not in swap_map: + o.tags.add(ProxyTag.RECOMPUTE_IN_BACKWARD) new_trace.bound_symbols.append( new_bsym.from_bsym_swap_proxies(swap_map).from_bsym( source_filename=bsym.source_filename, source_positions=bsym.source_positions diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index e28821a8b5..37c4da08f6 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -38,7 +38,7 @@ from thunder.core.compile_data import get_compile_data, get_compile_option from thunder.core.langctxs import langctx, Languages from thunder.core.pytree import tree_flatten, tree_map, tree_unflatten, tree_flatten_with_dataclass -from thunder.core.symbol import BoundSymbol, BoundSymbolInterface, Symbol +from thunder.core.symbol import BoundSymbol, BoundSymbolInterface, Symbol, BoundSymbolTag, has_tags from thunder.core.trace import TraceCtx as Trace from thunder.core.trace import VariableInterface as Variable from thunder.core.trace import ( @@ -59,6 +59,7 @@ unzip2, const_as, sequencify, + OrderedSet, ProxyDict, find_producer_symbols, ) @@ -92,6 +93,7 @@ TraceTag.register_tag("AUGMENTED_FORWARD") +BoundSymbolTag.register_tag("RECOMPUTE_IN_BACKWARD") # TODO This should be a partial of thunder.trace, but that would cause a circular import @@ -2991,6 +2993,7 @@ def unpacking_fn(saved_for_backward, cotangents): ) ) backward_trace.bound_symbols = list((*unpacking_trace.bound_symbols[:-1], *backward_trace_bsyms_without_unpacking)) + backward_trace.scopes[0] = backward_trace.bound_symbols def forward_and_backward_from_trace(trace: Trace, torch_autograd=False) -> ForwardBackwardTraces: @@ -3153,6 +3156,8 @@ def backward_fn(saved_for_backward, cotangents): enable_saved_for_backward_recomputation: None | bool = get_compile_option( "enable_saved_for_backward_recomputation", "Enable save for backward tensors recomputation." 
) + if enable_saved_for_backward_recomputation is None: + enable_saved_for_backward_recomputation = True if enable_saved_for_backward_recomputation: forward_trace, backward_trace = recompute_saved_for_backward(forward_trace, backward_trace) @@ -3171,11 +3176,32 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Tr start_time_ns = time.perf_counter_ns() + cd = get_compile_data() + have_nvfuser = any([ex.name == "nvfuser" for ex in cd.executors_list]) if cd is not None else False + if have_nvfuser: + from thunder.core.rematerialization import replace_uniform + + fwd_trace = replace_uniform(fwd_trace) + saved_for_bw = get_saved_for_backward_tensors(fwd_trace) - fwd_trace_args = {variableify(j) for j in fwd_trace.args} - old_saved_for_bwd = {variableify(j) for j in saved_for_bw} + fwd_trace_args = OrderedSet(variableify(j) for j in fwd_trace.args) + old_saved_for_bwd = OrderedSet(variableify(j) for j in saved_for_bw) + + all_proxies = fwd_trace_args.copy() + all_recomputable_proxies = OrderedSet() + + proxy_names_to_producers = {} + for bsym in fwd_trace.bound_symbols: + for o in bsym.flat_proxy_outs: + vo = variableify(o) + + if vo in all_proxies: + continue - all_rematerializable = old_saved_for_bwd - fwd_trace_args + proxy_names_to_producers[o.name] = bsym + all_proxies.add(vo) + if ProxyTag.RECOMPUTE_IN_BACKWARD in o.tags and not has_tags(bsym, {prims.OpTags.RANDOM_OP}): + all_recomputable_proxies.add(vo) remat_policy: None | Callable[[set[Variable]], set[Variable]] = get_compile_option( "recomputation_policy", @@ -3183,59 +3209,81 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Tr ) if remat_policy: - rematerializable = remat_policy(all_rematerializable) + rematerializable = remat_policy(old_saved_for_bwd - fwd_trace_args) else: - rematerializable = all_rematerializable - - producers = find_producer_symbols(fwd_trace, tuple(unvariableify(i) for i in rematerializable), fwd_trace.args) - - required_fw_args = fwd_trace_args & old_saved_for_bwd - recomputed_tensors_from_producers = set() - for prod in producers: - for prod_arg in prod.flat_args: - prod_arg = variableify(prod_arg) - if prod_arg in fwd_trace_args: - required_fw_args.add(prod_arg) - for prod_out in prod.flat_outs: - recomputed_tensors_from_producers.add(variableify(prod_out)) - - required_saved_for_bwd = all_rematerializable - rematerializable - recomputed_tensors_from_producers - new_saved_for_backward = tuple(unvariableify(i) for i in required_fw_args | required_saved_for_bwd) - - new_fwd_trace = from_trace(fwd_trace) - new_fwd_trace.bound_symbols = fwd_trace.bound_symbols.copy() - new_return_args = (fwd_trace.output[0], (new_saved_for_backward, fwd_trace.output[1][1])) - new_fwd_trace.bound_symbols[-1] = prims.python_return.bind(*new_return_args, output=None) + rematerializable = old_saved_for_bwd & all_recomputable_proxies - new_bwd_trace = from_trace(bwd_trace) - # In cases where C0 name is carried from previous trace it must be removed - # as the proxy needs to register with that specific name to follow the backward - # trace standard signature. - new_bwd_trace.names.discard("C0") - - with tracectx(new_bwd_trace): - unpack_args = (CollectionProxy(new_saved_for_backward, name="C0"), len(new_saved_for_backward)) + if not rematerializable: + return fwd_trace, bwd_trace # Here we make sure that the signature of the backward trace is the same as the one we expect. 
# This part of the trace is the unpacking of the tuple passed from the forward trace, - # more specifically, C0 unpacks into the saved for backward tensors and C1 into the cotangents - # used to compute the vector-Jacobian product. + # more specifically, C0 unpacks into the saved for backward tensors and C1 into the nontensors + # the cotangents used to compute the vector-Jacobian product are a separate argument. + assert bwd_trace.bound_symbols[2].sym.id == prims.PrimIDs.UNPACK_SEQUENCE + assert bwd_trace.bound_symbols[2].args[0].name == "saved_for_backward" assert bwd_trace.bound_symbols[4].sym.id == prims.PrimIDs.UNPACK_SEQUENCE assert bwd_trace.bound_symbols[4].args[0].name == "C0" assert bwd_trace.bound_symbols[5].sym.id == prims.PrimIDs.UNPACK_SEQUENCE assert bwd_trace.bound_symbols[5].args[0].name == "C1" + p_saved_for_backward = bwd_trace.bound_symbols[2].args[0] + p_c0 = bwd_trace.bound_symbols[4].args[0] + p_c1 = bwd_trace.bound_symbols[5].args[0] + + saved_tensors = fwd_trace_args & old_saved_for_bwd + saved_nontensors = OrderedSet(variableify(p) for p in p_c1.coll) + + new_bwd_trace = from_trace(bwd_trace) + + # args will be added from unpack_trivial + have_in_backward = saved_tensors | saved_nontensors + + def compute_proxy_from_producer(p): + vp = variableify(p) + if vp in have_in_backward: + return + if vp not in all_recomputable_proxies: + have_in_backward.add(vp) + if isinstance(p, TensorProxy): + saved_tensors.add(vp) + else: + saved_nontensors.add(vp) + return + producer_bsym = proxy_names_to_producers[p.name] + for p in producer_bsym.flat_proxy_args: + compute_proxy_from_producer(p) + for o in producer_bsym.flat_proxy_outs: + have_in_backward.add(variableify(o)) + new_bwd_trace.bound_symbols.append(producer_bsym) for idx, bsym in enumerate(bwd_trace.bound_symbols): - if idx == 4: - new_unpack = prims.unpack_sequence.bind(*unpack_args, output=new_saved_for_backward) - new_bwd_trace.bound_symbols.append(new_unpack) - elif idx == 6: - new_bwd_trace.bound_symbols.extend(producers) + if idx in {4, 5}: + # handled later new_bwd_trace.bound_symbols.append(bsym) else: + for p in bsym.flat_proxy_args: + compute_proxy_from_producer(p) + for o in bsym.flat_proxy_outs: + have_in_backward.add(variableify(o)) new_bwd_trace.bound_symbols.append(bsym) - new_bwd_trace.args = [(new_saved_for_backward, fwd_trace.output[1][1]), *bwd_trace.args[1:]] + new_fwd_trace = from_trace(fwd_trace) + new_fwd_trace.bound_symbols = fwd_trace.bound_symbols.copy() + + # TODO: fix ordering... 
+ new_c0 = tuple(unvariableify(i) for i in saved_tensors) + new_c1 = tuple(unvariableify(i) for i in saved_nontensors) + + new_return_args = (fwd_trace.output[0], (new_c0, new_c1)) + new_fwd_trace.bound_symbols[-1] = prims.python_return.bind(*new_return_args, output=None) + + p_saved_for_backward.coll = (new_c0, new_c1) + p_c0.coll = new_c0 + p_c1.coll = new_c1 + + new_bwd_trace.bound_symbols[4] = prims.unpack_sequence.bind(p_c0, len(new_c0), output=new_c0) + new_bwd_trace.bound_symbols[5] = prims.unpack_sequence.bind(p_c1, len(new_c1), output=new_c1) + new_bwd_trace.args = [(new_c0, new_c1), *bwd_trace.args[1:]] elapsed_time_ns = time.perf_counter_ns() - start_time_ns new_bwd_trace.set_provenance( diff --git a/thunder/examine/memory_calculation.py b/thunder/examine/memory_calculation.py index 14c2123ed2..c4255773c6 100644 --- a/thunder/examine/memory_calculation.py +++ b/thunder/examine/memory_calculation.py @@ -147,7 +147,7 @@ def clear_mutable_collection_argument_memory( return memory_size -def get_alloc_memory(trc: TraceCtx) -> tuple[int, dict[str, int]]: +def get_alloc_memory(trc: TraceCtx, *, annotate=False) -> tuple[int, dict[str, int]]: """ Calculate the memory usage based on the executable trace. The memory calculation is based only on the compile-time trace, i.e. the input and output shape @@ -189,6 +189,10 @@ def get_alloc_memory(trc: TraceCtx) -> tuple[int, dict[str, int]]: impl = partial(impl, is_argument=is_argument) allocated += impl(bsym, tensor_to_memory_data, name_to_alloc_memory) + if annotate: + if bsym.header: + bsym.header += " " + bsym.header += f"mem after next op: ~{allocated/(2**30):.2f}GB" max_allocated = max(max_allocated, allocated) return max_allocated, name_to_alloc_memory diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py index 02fa5081da..aef1c3e317 100644 --- a/thunder/executors/nvfuserex_impl.py +++ b/thunder/executors/nvfuserex_impl.py @@ -925,7 +925,10 @@ def fusion_pass(self, trace: TraceCtx) -> TraceCtx: # Some of the operations might be better placed with its consumers (for # example residual connection in transformer block). This pass moves # them to the consumer. 
- if self._use_rematerialization: + use_rematerialization: None | bool = get_compile_option( + "use_rematerialization", "Enable the rematerialization pass." + ) + if use_rematerialization and self._use_rematerialization: fusedtrace = rematerialize(fusedtrace) fusedtrace = remove_redundant_casts(fusedtrace) diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index 454582c92e..0c20f4866f 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -4,6 +4,7 @@ import torch import thunder.core.utils as utils +from thunder.core.compile_data import get_compile_option from thunder.core.prims import PrimIDs from thunder.core.proxies import TensorProxy, variableify from thunder.core.pytree import tree_flatten @@ -349,7 +350,11 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat ) bw_traces.append(bw_extrace) - fw_extrace, bw_extrace = rematerialize_forward_and_backward(fw_extrace, bw_extrace) + use_rematerialization: None | bool = get_compile_option( + "use_rematerialization", "Enable the rematerialization pass." + ) + if use_rematerialization: + fw_extrace, bw_extrace = rematerialize_forward_and_backward(fw_extrace, bw_extrace) fw_traces.append(fw_extrace) bw_traces.append(bw_extrace) diff --git a/thunder/executors/torch_compile.py b/thunder/executors/torch_compile.py index 7cb0fac0e0..e3df7c9440 100644 --- a/thunder/executors/torch_compile.py +++ b/thunder/executors/torch_compile.py @@ -16,6 +16,7 @@ from thunder.executors.passes import ( update_fusion_call_ctx, transform_for_execution, + del_last_used, ) from thunder.executors.utils import Region from thunder.extend import FusionExecutor, register_executor, ImplInfo, fuse_bound_symbols @@ -95,6 +96,7 @@ def make_compiled( region_trace._siginfo.args = [(a.name, None) for a in region_trace.args] torchex_trace = transform_for_execution(region_trace, executors_list=(pytorch_ex,)) + torchex_trace = del_last_used(torchex_trace) trace_callable = torchex_trace.python_callable(include_decorators=False) torch_compile_fullgraph: None | bool = get_compile_option( @@ -184,7 +186,11 @@ def fusion_pass(self, trace: TraceCtx) -> TraceCtx: fusedtrace.bound_symbols = fused_bsyms - fusedtrace = rematerialize(fusedtrace) + use_rematerialization: None | bool = get_compile_option( + "use_rematerialization", "Enable the rematerialization pass." + ) + if use_rematerialization: + fusedtrace = rematerialize(fusedtrace) fusedtrace = dce(fusedtrace) fusedtrace = update_fusion_call_ctx(fusedtrace) diff --git a/thunder/tests/test_examine_memory.py b/thunder/tests/test_examine_memory.py index 51278003de..f270ecc199 100644 --- a/thunder/tests/test_examine_memory.py +++ b/thunder/tests/test_examine_memory.py @@ -113,7 +113,7 @@ def test_nanogpt_block(): # Actual memory usage may vary depending on hardware and cuBLAS settings. # We are checking the estimated memory against a fixed value for consistency. 
- assert max_mem_fw[0] == 381754368 - assert sum(max_mem_fw[1].values()) == 375462912 - assert max_mem_bw[0] == 437292032 - assert sum(max_mem_bw[1].values()) == 40934400 + assert max_mem_fw[0] == 262183936 + assert sum(max_mem_fw[1].values()) == 135306240 + assert max_mem_bw[0] == 472259584 + assert sum(max_mem_bw[1].values()) == 157341696 diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py index 73e3453e9a..0d0a3da6da 100644 --- a/thunder/tests/test_grad.py +++ b/thunder/tests/test_grad.py @@ -27,6 +27,7 @@ run_snippet, assert_closer, IN_CI, + NVFUSER_AVAILABLE, requiresCUDA, version_between, ) @@ -1179,12 +1180,15 @@ def func(a, b, *, c): a = make_tensor((2, 3), device=device, dtype=torch.float64, requires_grad=True) b = make_tensor((2, 3), device=device, dtype=torch.float64, requires_grad=True) c = make_tensor((3,), device=device, dtype=torch.float64, requires_grad=True) - initial_trace = trace(inline_trace=False)(func, a, b, c=c) + jfn = thunder.jit(func) + cd, inps, _ = thunder.compile_data(jfn).get_computation_and_inputs(a, b, c=c) + initial_trace = cd.computation_traces[0] wrapped_trace = wrap_return_value_together_with_arguments(initial_trace) + fw_trace, bw_trace = forward_and_backward_from_trace(wrapped_trace) fw = executor.make_callable(fw_trace) bw = executor.make_callable(bw_trace) - fw_out, saved_for_backward = fw(a, b, c=c) + fw_out, saved_for_backward = fw(*inps) initial_trace = trace()(value_and_grad(func), a, b, c=c) expected_vjp_func = executor.make_callable(initial_trace.python_callable(), disable_torch_autograd=True) @@ -1194,6 +1198,7 @@ def func(a, b, *, c): output_grads = tree_map(lambda x: torch.ones_like(x), fw_out["output"]) bw_out = bw(saved_for_backward, output_grads) + expected_grads = (*expected_grads[:-1], expected_grads[-1]["c"]) torch.testing.assert_close(bw_out, expected_grads) @@ -1754,8 +1759,8 @@ def f(x, y): # The intermediate values are recomputed during backward pass. assert len(out.grad_fn.next_functions[0][0].saved_tensors) == 2 # We detach the saved tensors (which returns a new Python tensor backed by same storage) - assert out.grad_fn.next_functions[0][0].saved_tensors[0].data_ptr() == x.data_ptr() - assert out.grad_fn.next_functions[0][0].saved_tensors[1].data_ptr() == y.data_ptr() + # the order seems to be non-deterministic sometimes + assert {t.data_ptr() for t in out.grad_fn.next_functions[0][0].saved_tensors} == {x.data_ptr(), y.data_ptr()} g = torch.ones_like(out) out.backward(g) @@ -1768,6 +1773,52 @@ def f(x, y): torch.testing.assert_close(y.grad, y_ref.grad) +@requiresCUDA +def test_checkpoint_max_memory(): + import torch.utils.checkpoint + + class Checkpoint(torch.nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self, *args): + return torch.utils.checkpoint.checkpoint(self.module, *args, use_reentrant=False) + + with torch.device("cuda:0"): + m = torch.nn.Sequential( + torch.nn.Linear(1024, 16), + torch.nn.ReLU(), + *[ + Checkpoint( + torch.nn.Sequential( + torch.nn.Linear(16, 2048), + torch.nn.Linear(2048, 16), + torch.nn.ReLU(), + ) + ) + for _ in range(10) + ], + torch.nn.Linear(16, 1024), + ) + inps = torch.randn(512, 1024, requires_grad=True) + + jm = thunder.jit(m, executors=()) # no rematerialization + res = jm(inps) + res.sum().backward() + + torch.cuda.reset_peak_memory_stats() + mem_base = torch.cuda.memory_allocated() + res = jm(inps) + res.sum().backward() + mem_max = torch.cuda.max_memory_allocated() + # without checkpointing, the peak memory is about 43MB. 
+ # With checkpointing as coded in the model and recomputation where the + # values are used, we get a little over 10MB, so we put the barrier at 16MB + mb_used = (mem_max - mem_base) / 2**20 + assert mb_used < 16 + + def test_inconsistent_output_length_grad_transform(): from thunder.extend import OperatorExecutor from thunder.core.proxies import AnyProxy, TensorProxy @@ -1886,3 +1937,31 @@ def func(x): torch.testing.assert_close(actual, expected) torch.testing.assert_close(actual_gr, expected_gr) + + +@pytest.mark.parametrize("device", ("cuda", "cpu")) +def test_backward_recomputation_decomposed_ops(device): + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + def fn(a): + return torch.nn.functional.gelu(a) + + jfn = thunder.jit(fn, enable_saved_for_backward_recomputation=False) + jfn2 = thunder.jit(fn, enable_saved_for_backward_recomputation=True) + a = torch.randn(2, 2, device=device, requires_grad=True) + res = jfn(a) + res2 = jfn2(a) + assert len(res.grad_fn.next_functions[0][0].saved_tensors) == 3 # should be decomposed + assert len(res2.grad_fn.next_functions[0][0].saved_tensors) == 1 + + if NVFUSER_AVAILABLE and device == "cuda": + # check everything is fused + assert {bsym.sym.name for bsym in thunder.last_backward_traces(jfn2)[-1].bound_symbols} == { + "nvFusion0", + "clear_mutable_collection", + "python_return", + "python_del", + "unpack_sequence", + "unpack_trivial", + } diff --git a/thunder/tests/test_jit_general.py b/thunder/tests/test_jit_general.py index e1be6d86b8..80e4301398 100644 --- a/thunder/tests/test_jit_general.py +++ b/thunder/tests/test_jit_general.py @@ -1433,7 +1433,7 @@ def foo(a): a = torch.reshape(a, [a.numel()]) return a.relu() - jfoo = thunder.jit(foo, cache="symbolic values") + jfoo = thunder.jit(foo, cache="symbolic values", enable_saved_for_backward_recomputation=False) a = torch.randn(2, 3, 8, requires_grad=True, device="cpu") diff --git a/thunder/tests/test_networks.py b/thunder/tests/test_networks.py index 372d2cfe2e..fab3e500e2 100644 --- a/thunder/tests/test_networks.py +++ b/thunder/tests/test_networks.py @@ -553,4 +553,4 @@ def test_hf_llama(): top_level_symbol_names = {bsym.sym.name for bsym in thunder.last_traces(jm)[-1].bound_symbols} # changes this to fewer as needed, the goal is to not have too many fusions - assert len([s for s in top_level_symbol_names if s.startswith("nvFusion")]) == 7 + assert len([s for s in top_level_symbol_names if s.startswith("nvFusion")]) == 8 diff --git a/thunder/tests/test_nvfuser_remat.py b/thunder/tests/test_nvfuser_remat.py index f1218f91dd..d0121a07e4 100644 --- a/thunder/tests/test_nvfuser_remat.py +++ b/thunder/tests/test_nvfuser_remat.py @@ -67,7 +67,7 @@ def test_find_producer_symbols(executor, device, _): # We will try to find a subgraph for rematerializing __c and __d t0 = make_tensor(2, 2, dtype=torch.float32, device=device) initial_trace = thunder.trace()(func, t0) - compiled_func = thunder.jit(initial_trace.python_callable()) + compiled_func = thunder.jit(initial_trace.python_callable(), use_rematerialization=True) _ = compiled_func(t0) traces = thunder.last_traces(compiled_func) trace = traces[-1] @@ -117,7 +117,7 @@ def test_find_producer_symbols(executor, device, _): def test_apply_rematerialization_producer(executor, device, _): t0 = make_tensor(2, 2, dtype=torch.float32, device=device) initial_trace = thunder.trace()(func, t0) - compiled_func = thunder.jit(initial_trace.python_callable()) + compiled_func = 
thunder.jit(initial_trace.python_callable(), use_rematerialization=True) _ = compiled_func(t0) traces = thunder.last_traces(compiled_func) trace = traces[-1] @@ -151,7 +151,7 @@ def test_apply_rematerialization_producer(executor, device, _): def test_apply_rematerialization_consumer(executor, device, _): t0 = make_tensor(2, 2, dtype=torch.float32, device=device) initial_trace = thunder.trace()(func, t0) - compiled_func = thunder.jit(initial_trace.python_callable()) + compiled_func = thunder.jit(initial_trace.python_callable(), use_rematerialization=True) _ = compiled_func(t0) traces = thunder.last_traces(compiled_func) trace = traces[-1] @@ -217,7 +217,7 @@ def foo(t0): t0 = make_tensor(2, 2, dtype=torch.float32, device=device) initial_trace = thunder.trace()(foo, t0) - compiled_func = thunder.jit(initial_trace.python_callable()) + compiled_func = thunder.jit(initial_trace.python_callable(), use_rematerialization=True) _ = compiled_func(t0) traces = thunder.last_traces(compiled_func) trace = traces[-1] @@ -257,7 +257,7 @@ def func(t0): t0 = make_tensor(2, 2, dtype=torch.float32, device=device) initial_trace = thunder.trace()(func, t0) - compiled_func = thunder.jit(initial_trace.python_callable()) + compiled_func = thunder.jit(initial_trace.python_callable(), use_rematerialization=True) _ = compiled_func(t0) traces = thunder.last_traces(compiled_func) trace = traces[-1] @@ -309,7 +309,7 @@ def func(t0): return t3, t4 t0 = make_tensor(2, 2, dtype=torch.float32, device=device) - compiled_func = executor.make_callable(func) + compiled_func = executor.make_callable(func, enable_saved_for_backward_recomputation=False) _ = compiled_func(t0) traces = thunder.last_traces(compiled_func) trace = traces[0] @@ -369,7 +369,12 @@ def func( from thunder.executors.torch_compile import torch_compile_cat_ex try: - compiled_func = thunder.jit(func, executors=(torch_compile_cat_ex, thunder.nvfuser_executor)) + compiled_func = thunder.jit( + func, + executors=(torch_compile_cat_ex, thunder.nvfuser_executor), + use_rematerialization=True, + enable_saved_for_backward_recomputation=False, + ) _ = compiled_func( t0, t1, @@ -391,7 +396,7 @@ def func( def test_find_cut(executor, device, _): t0 = make_tensor(2, 2, dtype=torch.float32, device=device) intial_trace = thunder.trace()(func, t0) - compiled_func = thunder.jit(intial_trace.python_callable()) + compiled_func = thunder.jit(intial_trace.python_callable(), use_rematerialization=True) _ = compiled_func(t0) traces = thunder.last_traces(compiled_func) trace = traces[-1] @@ -418,7 +423,7 @@ def test_find_cut_dropout(executor, device, _): with patch("thunder.core.rematerialization.replace_uniform", new=replace_uniform_mock): intial_trace = thunder.trace()(func_with_dropout, t0) - compiled_func = thunder.jit(intial_trace.python_callable()) + compiled_func = thunder.jit(intial_trace.python_callable(), use_rematerialization=True) _ = compiled_func(t0) traces = thunder.last_traces(compiled_func) trace = traces[-1] @@ -459,10 +464,12 @@ def func(t0): # Result with rematerialization and without rematerialization should match initial_trace = thunder.trace()(func, t0) - result_with_remat = thunder.jit(initial_trace.python_callable())(t0) + result_with_remat = thunder.jit(initial_trace.python_callable(), use_rematerialization=True)(t0) assert not isinstance(result_with_remat, Exception) - result_without_remat = disable_rematerialization_in_nvfuser_fusion(thunder.jit(initial_trace.python_callable()))(t0) + result_without_remat = disable_rematerialization_in_nvfuser_fusion( + 
thunder.jit(initial_trace.python_callable(), use_rematerialization=True) + )(t0) torch.testing.assert_close(result_with_remat, result_without_remat) @@ -474,7 +481,7 @@ def test_rematerialization_name_collision(): def forward(x): return x.softmax(dim=1, dtype=torch.float) - jforward = thunder.jit(forward) + jforward = thunder.jit(forward, use_rematerialization=True) x = torch.randn([32768, 8], dtype=torch.bfloat16, device="cuda", requires_grad=True) @@ -515,7 +522,7 @@ def forward(self, x): # At the time of writing, linear and matmul are not fused into nvFuser # regions by default therefore, we should enable them separately - jmodel = thunder.jit(model, nv_enable_linear=True, nv_enable_matmul=True) + jmodel = thunder.jit(model, nv_enable_linear=True, nv_enable_matmul=True, use_rematerialization=True) jmodel(inp) def assert_subsymbol_count(trace: TraceCtx, /, num_linears: int, num_matmuls: int, num_fusions: int): diff --git a/thunder/tests/test_transforms.py b/thunder/tests/test_transforms.py index 2418eeae6a..45f118e7af 100644 --- a/thunder/tests/test_transforms.py +++ b/thunder/tests/test_transforms.py @@ -375,7 +375,7 @@ def forward(self, x): model = MyModel() # Do not recompute anything - jmodel = thunder.jit(model) + jmodel = thunder.jit(model, enable_saved_for_backward_recomputation=False) jmodel(a) fwd_trace = None diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index 38d0353c1f..018ca076b8 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -63,7 +63,6 @@ # NOTE torch is a requirement import torch -import torch.utils.checkpoint import torch._higher_order_ops.wrap import warnings @@ -5325,7 +5324,6 @@ def _unwrap_if_dead(tensor): @torchsymbol( - torch.utils.checkpoint.checkpoint, torch.ops.higher_order.tag_activation_checkpoint, id="activation_checkpoint", )
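Note (not part of the patch): a minimal usage sketch of the knobs this diff introduces, assuming it is applied — the `enable_saved_for_backward_recomputation` and `use_rematerialization` compile options passed to `thunder.jit` (as in the test changes above) and the new `annotate` flag of `get_alloc_memory` (as in the benchmark changes above).

import torch
import thunder
from thunder.examine.memory_calculation import get_alloc_memory
from thunder.executors.passes import del_last_used


def fn(a):
    return torch.nn.functional.gelu(a)


# Recompute saved-for-backward tensors in the backward trace and leave the
# fusion-level rematerialization pass off (it becomes opt-in with this patch).
jfn = thunder.jit(fn, enable_saved_for_backward_recomputation=True, use_rematerialization=False)

a = torch.randn(2, 2, requires_grad=True)
jfn(a).sum().backward()

# Estimate peak memory of the backward execution trace, as the benchmark changes do;
# annotate=True (new in this diff) also records the running allocation in each bsym header.
bw_trace = del_last_used(thunder.last_backward_traces(jfn)[-1])
peak_bytes, _ = get_alloc_memory(bw_trace, annotate=True)
print(f"estimated backward peak: ~{peak_bytes / (2**30):.2f}GB")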