
update checkpointing support for jit #1560

Closed · wants to merge 17 commits
17 changes: 10 additions & 7 deletions thunder/benchmarks/benchmark_litgpt.py
@@ -541,19 +541,15 @@ def setup_distributed(self, model):
return model

def setup_activation_checkpointing(self):
if "thunder" in self.compile and "dynamo" not in self.compile:
# checkpointing is an option to thunder.jit
return

if any(isinstance(mod, CheckpointWrapper) for mod in self.model.modules()):
warnings.warn(
"FSDP checkpointing is configured, but the model already contains checkpointed layers."
" Checkpointing will be ignored."
)
return

check_fn = lambda submodule: isinstance(submodule, Block)
apply_activation_checkpointing(self.model, checkpoint_wrapper_fn=checkpoint_wrapper, check_fn=check_fn)
print(self.model)

# TODO(crcrpar): Think of apply `torch.compile` or `thunder.jit` per block/module
# like https://github.com/pytorch/torchtitan/blob/cfc0f4e/torchtitan/parallelisms/parallelize_llama.py#L275-L284
@@ -890,12 +886,19 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None
print(f"##########\n#Graph{gid}-ThunderFn{subgid} last backward trace\n##########")
print(thunder.last_backward_traces(thunder_fn)[-1])
else:
from thunder.examine.memory_calculation import get_alloc_memory
from thunder.executors.passes import del_last_used

for i, f_traces in enumerate(fwd_traces, start=1):
print(f"##########\n#{i}-th ThunderModule\n##########")
print(f_traces[-1])
print(f_traces)
for i, b_traces in enumerate(bwd_traces, start=1):
print(f"##########\n#{i}-th ThunderModule\n##########")
print(b_traces[-1])
for tr in b_traces:
dltr = del_last_used(tr)
tr_peak_memory, _ = get_alloc_memory(dltr)
print(f"#the following trace uses ~{tr_peak_memory/(2**30):.2f}GB memory")
print(tr)

if global_rank in [0, None]:
if return_metrics_as_json:
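
As context for the new per-trace memory report above, the same estimate can be reproduced on any jitted module. A minimal sketch, assuming the module has been run once so Thunder has recorded its traces (shapes and the model are made up for illustration):

    import torch
    import thunder
    from thunder.examine.memory_calculation import get_alloc_memory
    from thunder.executors.passes import del_last_used

    model = torch.nn.Linear(512, 512)
    jitted = thunder.jit(model)
    jitted(torch.randn(4, 512)).sum().backward()

    # Estimate the peak allocation of each backward trace, mirroring the benchmark output.
    for trace in thunder.last_backward_traces(jitted):
        trimmed = del_last_used(trace)  # insert del statements so tensor lifetimes end as early as possible
        peak_bytes, _ = get_alloc_memory(trimmed)
        print(f"#the following trace uses ~{peak_bytes / 2**30:.2f}GB memory")
        print(trace)
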
58 changes: 48 additions & 10 deletions thunder/core/jit_ext.py
@@ -895,11 +895,18 @@ def _general_jit_torch_finfo_lookaside(dtype: thunder.dtypes.dtype):
return res


ProxyTag.register_tag("RECOMPUTE_IN_BACKWARD")


@register_general_jit_lookaside(torch.utils.checkpoint.checkpoint)
def _general_jit_torch_checkpoint_lookaside(
function: Callable,
*args,
**kwargs: Any,
context_fn: None | Callable[..., Any] = None,
debug: None | bool = None,
determinism_check: None | str = None,
preserve_rng_state: None | bool = None,
use_reentrant: bool = False,
):
"""
This function does preprocessing of the `function` argument before
@@ -917,17 +924,48 @@ def _general_jit_torch_checkpoint_lookaside(
The result of calling `thunder.torch.checkpoint` with the preprocessed
`function` and its arguments.
"""
from thunder.torch import checkpoint

# It should be possible to call the general_thunder_jit here to handle the
# conversion from torch to thunder but it doesn't work now
# See https://github.com/Lightning-AI/lightning-thunder/issues/1126
# TODO: Convert the function to a Thunder function
def thunder_function(*args, **kwargs):
return unwrap(function)(*args, **kwargs)
if unwrap(use_reentrant):
return do_raise(
"torch.checkpoint: use_reentrant=True is not supported in Thunder",
)
# NOTE: Thunder currently ignores the context_fn, debug, determinism_check, preserve_rng_state arguments
# Let's raise a warning if any of these arguments are passed
if unwrap(context_fn) is not None:
warnings.warn("torch.checkpoint: context_fn is not supported in Thunder and will be ignored")
if unwrap(debug) is not None:
warnings.warn("torch.checkpoint: debug is not supported in Thunder and will be ignored")
if unwrap(determinism_check) is not None:
warnings.warn("torch.checkpoint: determinism_check is not supported in Thunder and will be ignored")
if unwrap(preserve_rng_state) is not None:
warnings.warn("torch.checkpoint: preserve_rng_state is not supported in Thunder and will be ignored")

jit_ctx: JitCtx = get_jit_ctx()
jit_ctx.computation_trace.push_scope([])

input_output_proxy_names = set()

def add_input_output_proxy_name(p):
if isinstance(p, Proxy):
input_output_proxy_names.add(p.name)

tree_map(add_input_output_proxy_name, [unwrap(a) for a in args])

wrapped_thunder_function = wrap_const(thunder_function)
return interpreter_needs_wrap(checkpoint)(wrapped_thunder_function, *args, **kwargs)
res = _interpret_call(function, *args)
if res is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
return res

tree_map(add_input_output_proxy_name, unwrap(res))

new_bsyms = jit_ctx.computation_trace.pop_scope()
jit_ctx.computation_trace.bound_symbols.extend(new_bsyms)

for bsym in new_bsyms:
for o in bsym.flat_proxy_outs:
if o.name not in input_output_proxy_names:
o.tags.add(ProxyTag.RECOMPUTE_IN_BACKWARD)

return res


# Adds proxy methods
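
For reference, the lookaside above is what lets torch.utils.checkpoint.checkpoint be called directly inside a thunder.jit'd module. A minimal usage sketch (the module and shapes are made up for illustration):

    import torch
    import torch.utils.checkpoint
    import thunder

    class Block(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(256, 256)
            self.fc2 = torch.nn.Linear(256, 256)

        def forward(self, x):
            # Proxies produced inside the checkpointed region are tagged
            # RECOMPUTE_IN_BACKWARD instead of being saved for the backward pass.
            return torch.utils.checkpoint.checkpoint(
                lambda t: self.fc2(torch.nn.functional.gelu(self.fc1(t))),
                x,
                use_reentrant=False,  # use_reentrant=True raises under Thunder
            )

    jitted = thunder.jit(Block())
    out = jitted(torch.randn(8, 256, requires_grad=True))
    out.sum().backward()
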
4 changes: 3 additions & 1 deletion thunder/core/rematerialization.py
@@ -192,7 +192,9 @@ def apply_rematerialization_for_consumer(
_, leaves = bsym_list_to_dag(list(new_subsymbols))
new_subsymbols = toposort_bsym_dag(leaves, TOPOSORT_ORDER.BOTTOM_UP)
proxy_order = order_proxies(new_subsymbols)
new_consumer_args = tuple(sorted(new_consumer_args, key=lambda x: proxy_order[x.name]))
new_consumer_args = tuple(
sorted((a for a in new_consumer_args if a.name in proxy_order), key=lambda x: proxy_order[x.name])
)
new_consumer = replace(consumer, args=new_consumer_args, subsymbols=new_subsymbols)
return new_consumer

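The filter added above drops consumer arguments whose proxies have no position in proxy_order instead of raising a KeyError during the sort. An isolated illustration of the idiom, with hypothetical names standing in for proxies:

    # proxy_order maps proxy names to their position in the new subsymbol ordering.
    proxy_order = {"t0": 0, "t2": 1}
    consumer_args = ["t3", "t2", "t0"]  # "t3" is not produced by the new subsymbols

    kept = tuple(sorted((a for a in consumer_args if a in proxy_order), key=lambda a: proxy_order[a]))
    assert kept == ("t0", "t2")  # unknown arguments are skipped rather than crashing the sort
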
9 changes: 8 additions & 1 deletion thunder/core/symbol.py
@@ -16,7 +16,7 @@
import thunder.core.baseutils as baseutils
import thunder.core.codeutils as codeutils
from thunder.core.codeutils import Printable, Positions
from thunder.core.baseutils import BoundSymbolInterface, ProxyInterface
from thunder.core.baseutils import BoundSymbolInterface, ProxyInterface, TagBase
from thunder.core.utils import FrozenDict, make_hashable
from thunder.core.pytree import tree_flatten_with_dataclass, tree_unflatten, tree_map
import thunder.core.dtypes as dtypes
@@ -351,6 +351,10 @@ def tag_tensorproxy_output_as_detached(proxy):
return result


class BoundSymbolTag(TagBase):
pass


# A symbol, arguments (and kwarguments), output, and sub-symbols
# args is a sequence of the arguments
# kwargs is a dict of the kwargs
@@ -377,6 +381,8 @@ class BoundSymbol(BoundSymbolInterface):
source_filename: str | None = None
source_positions: Positions | None = None

bsym_tags: set[BoundSymbolTag] = field(default_factory=set)

_call_ctx: None | dict[str, Any] = None

_import_ctx: dict = field(default_factory=dict)
@@ -412,6 +418,7 @@ def from_bsym(self, **kwargs) -> BoundSymbol:
"_import_ctx": self._import_ctx,
"_object_ctx": self._object_ctx,
"_executor": self._executor,
"bsym_tags": self.bsym_tags.copy(),
}

self_kwargs.update(kwargs)
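
A small sketch of the new field in use. EXAMPLE_TAG is a hypothetical tag registered only for this illustration (the PR itself registers RECOMPUTE_IN_BACKWARD in thunder/core/transforms.py), and the copy semantics rely on from_bsym copying bsym_tags as shown above:

    import torch
    import thunder
    from thunder.core.symbol import BoundSymbolTag

    BoundSymbolTag.register_tag("EXAMPLE_TAG")  # hypothetical tag, for illustration only

    def fn(x):
        return x * 2

    jfn = thunder.jit(fn)
    jfn(torch.randn(3))

    bsym = thunder.last_traces(jfn)[-1].bound_symbols[0]
    copy = bsym.from_bsym()  # bsym_tags is copied, so tagging the copy leaves the original untouched
    copy.bsym_tags.add(BoundSymbolTag.EXAMPLE_TAG)
    assert BoundSymbolTag.EXAMPLE_TAG not in bsym.bsym_tags
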
5 changes: 4 additions & 1 deletion thunder/core/trace_interpreter.py
@@ -6,7 +6,7 @@
from thunder.core.trace import VariableInterface, from_trace, tracectx
from thunder.core.baseutils import ProxyInterface, TensorProxyInterface
from thunder.core.utils import safe_map_flat, sequencify
from thunder.core.proxies import variableify
from thunder.core.proxies import variableify, ProxyTag
from thunder.core.transform_common import VJPDual


@@ -183,6 +183,9 @@ def do_swap(v):

for new_bsym in new_bsyms:
# TODO: what to do with bsym header? Maybe have a combined from_bsym_swap_proxies and from_bsym?
for o in new_bsym.flat_proxy_outs:
if variableify(o) not in swap_map:
o.tags.add(ProxyTag.RECOMPUTE_IN_BACKWARD)
new_trace.bound_symbols.append(
new_bsym.from_bsym_swap_proxies(swap_map).from_bsym(
source_filename=bsym.source_filename, source_positions=bsym.source_positions
132 changes: 90 additions & 42 deletions thunder/core/transforms.py
@@ -38,7 +38,7 @@
from thunder.core.compile_data import get_compile_data, get_compile_option
from thunder.core.langctxs import langctx, Languages
from thunder.core.pytree import tree_flatten, tree_map, tree_unflatten, tree_flatten_with_dataclass
from thunder.core.symbol import BoundSymbol, BoundSymbolInterface, Symbol
from thunder.core.symbol import BoundSymbol, BoundSymbolInterface, Symbol, BoundSymbolTag, has_tags
from thunder.core.trace import TraceCtx as Trace
from thunder.core.trace import VariableInterface as Variable
from thunder.core.trace import (
@@ -59,6 +59,7 @@
unzip2,
const_as,
sequencify,
OrderedSet,
ProxyDict,
find_producer_symbols,
)
@@ -92,6 +93,7 @@


TraceTag.register_tag("AUGMENTED_FORWARD")
BoundSymbolTag.register_tag("RECOMPUTE_IN_BACKWARD")


# TODO This should be a partial of thunder.trace, but that would cause a circular import
@@ -2991,6 +2993,7 @@ def unpacking_fn(saved_for_backward, cotangents):
)
)
backward_trace.bound_symbols = list((*unpacking_trace.bound_symbols[:-1], *backward_trace_bsyms_without_unpacking))
backward_trace.scopes[0] = backward_trace.bound_symbols


def forward_and_backward_from_trace(trace: Trace, torch_autograd=False) -> ForwardBackwardTraces:
@@ -3153,6 +3156,8 @@ def backward_fn(saved_for_backward, cotangents):
enable_saved_for_backward_recomputation: None | bool = get_compile_option(
"enable_saved_for_backward_recomputation", "Enable save for backward tensors recomputation."
)
if enable_saved_for_backward_recomputation is None:
enable_saved_for_backward_recomputation = True
if enable_saved_for_backward_recomputation:
forward_trace, backward_trace = recompute_saved_for_backward(forward_trace, backward_trace)

@@ -3171,71 +3176,114 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Tr

start_time_ns = time.perf_counter_ns()

cd = get_compile_data()
have_nvfuser = any([ex.name == "nvfuser" for ex in cd.executors_list]) if cd is not None else False
if have_nvfuser:
from thunder.core.rematerialization import replace_uniform

fwd_trace = replace_uniform(fwd_trace)

saved_for_bw = get_saved_for_backward_tensors(fwd_trace)
fwd_trace_args = {variableify(j) for j in fwd_trace.args}
old_saved_for_bwd = {variableify(j) for j in saved_for_bw}
fwd_trace_args = OrderedSet(variableify(j) for j in fwd_trace.args)
old_saved_for_bwd = OrderedSet(variableify(j) for j in saved_for_bw)

all_proxies = fwd_trace_args.copy()
all_recomputable_proxies = OrderedSet()

proxy_names_to_producers = {}
for bsym in fwd_trace.bound_symbols:
for o in bsym.flat_proxy_outs:
vo = variableify(o)

if vo in all_proxies:
continue

all_rematerializable = old_saved_for_bwd - fwd_trace_args
proxy_names_to_producers[o.name] = bsym
all_proxies.add(vo)
if ProxyTag.RECOMPUTE_IN_BACKWARD in o.tags and not has_tags(bsym, {prims.OpTags.RANDOM_OP}):
all_recomputable_proxies.add(vo)

remat_policy: None | Callable[[set[Variable]], set[Variable]] = get_compile_option(
"recomputation_policy",
"A callable that accepts a set of variables and returns a set of the variables that are allowed to be recomputed from the forward in the backward trace. The compile option `enable_saved_for_backward_recomputation` needs to be true for this policy to take effect.",
)

if remat_policy:
rematerializable = remat_policy(all_rematerializable)
rematerializable = remat_policy(old_saved_for_bwd - fwd_trace_args)
else:
rematerializable = all_rematerializable

producers = find_producer_symbols(fwd_trace, tuple(unvariableify(i) for i in rematerializable), fwd_trace.args)

required_fw_args = fwd_trace_args & old_saved_for_bwd
recomputed_tensors_from_producers = set()
for prod in producers:
for prod_arg in prod.flat_args:
prod_arg = variableify(prod_arg)
if prod_arg in fwd_trace_args:
required_fw_args.add(prod_arg)
for prod_out in prod.flat_outs:
recomputed_tensors_from_producers.add(variableify(prod_out))

required_saved_for_bwd = all_rematerializable - rematerializable - recomputed_tensors_from_producers
new_saved_for_backward = tuple(unvariableify(i) for i in required_fw_args | required_saved_for_bwd)

new_fwd_trace = from_trace(fwd_trace)
new_fwd_trace.bound_symbols = fwd_trace.bound_symbols.copy()
new_return_args = (fwd_trace.output[0], (new_saved_for_backward, fwd_trace.output[1][1]))
new_fwd_trace.bound_symbols[-1] = prims.python_return.bind(*new_return_args, output=None)
rematerializable = old_saved_for_bwd & all_recomputable_proxies

new_bwd_trace = from_trace(bwd_trace)
# In cases where C0 name is carried from previous trace it must be removed
# as the proxy needs to register with that specific name to follow the backward
# trace standard signature.
new_bwd_trace.names.discard("C0")

with tracectx(new_bwd_trace):
unpack_args = (CollectionProxy(new_saved_for_backward, name="C0"), len(new_saved_for_backward))
if not rematerializable:
return fwd_trace, bwd_trace

# Here we make sure that the signature of the backward trace is the same as the one we expect.
# This part of the trace is the unpacking of the tuple passed from the forward trace,
# more specifically, C0 unpacks into the saved for backward tensors and C1 into the cotangents
# used to compute the vector-Jacobian product.
# more specifically, C0 unpacks into the saved for backward tensors and C1 into the nontensors
# the cotangents used to compute the vector-Jacobian product are a separate argument.
assert bwd_trace.bound_symbols[2].sym.id == prims.PrimIDs.UNPACK_SEQUENCE
assert bwd_trace.bound_symbols[2].args[0].name == "saved_for_backward"
assert bwd_trace.bound_symbols[4].sym.id == prims.PrimIDs.UNPACK_SEQUENCE
assert bwd_trace.bound_symbols[4].args[0].name == "C0"
assert bwd_trace.bound_symbols[5].sym.id == prims.PrimIDs.UNPACK_SEQUENCE
assert bwd_trace.bound_symbols[5].args[0].name == "C1"
p_saved_for_backward = bwd_trace.bound_symbols[2].args[0]
p_c0 = bwd_trace.bound_symbols[4].args[0]
p_c1 = bwd_trace.bound_symbols[5].args[0]

saved_tensors = fwd_trace_args & old_saved_for_bwd
saved_nontensors = OrderedSet(variableify(p) for p in p_c1.coll)

new_bwd_trace = from_trace(bwd_trace)

# args will be added from unpack_trivial
have_in_backward = saved_tensors | saved_nontensors

def compute_proxy_from_producer(p):
vp = variableify(p)
if vp in have_in_backward:
return
if vp not in all_recomputable_proxies:
have_in_backward.add(vp)
if isinstance(p, TensorProxy):
saved_tensors.add(vp)
else:
saved_nontensors.add(vp)
return
producer_bsym = proxy_names_to_producers[p.name]
for p in producer_bsym.flat_proxy_args:
compute_proxy_from_producer(p)
for o in producer_bsym.flat_proxy_outs:
have_in_backward.add(variableify(o))
new_bwd_trace.bound_symbols.append(producer_bsym)

for idx, bsym in enumerate(bwd_trace.bound_symbols):
if idx == 4:
new_unpack = prims.unpack_sequence.bind(*unpack_args, output=new_saved_for_backward)
new_bwd_trace.bound_symbols.append(new_unpack)
elif idx == 6:
new_bwd_trace.bound_symbols.extend(producers)
if idx in {4, 5}:
# handled later
new_bwd_trace.bound_symbols.append(bsym)
else:
for p in bsym.flat_proxy_args:
compute_proxy_from_producer(p)
for o in bsym.flat_proxy_outs:
have_in_backward.add(variableify(o))
new_bwd_trace.bound_symbols.append(bsym)

new_bwd_trace.args = [(new_saved_for_backward, fwd_trace.output[1][1]), *bwd_trace.args[1:]]
new_fwd_trace = from_trace(fwd_trace)
new_fwd_trace.bound_symbols = fwd_trace.bound_symbols.copy()

# TODO: fix ordering...
new_c0 = tuple(unvariableify(i) for i in saved_tensors)
new_c1 = tuple(unvariableify(i) for i in saved_nontensors)

new_return_args = (fwd_trace.output[0], (new_c0, new_c1))
new_fwd_trace.bound_symbols[-1] = prims.python_return.bind(*new_return_args, output=None)

p_saved_for_backward.coll = (new_c0, new_c1)
p_c0.coll = new_c0
p_c1.coll = new_c1

new_bwd_trace.bound_symbols[4] = prims.unpack_sequence.bind(p_c0, len(new_c0), output=new_c0)
new_bwd_trace.bound_symbols[5] = prims.unpack_sequence.bind(p_c1, len(new_c1), output=new_c1)
new_bwd_trace.args = [(new_c0, new_c1), *bwd_trace.args[1:]]

elapsed_time_ns = time.perf_counter_ns() - start_time_ns
new_bwd_trace.set_provenance(
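
The recomputation above is now on by default and is steered through compile options. A hedged sketch of opting out or restricting it, assuming compile options can be passed as keyword arguments to thunder.jit (which is where get_compile_option reads them from):

    import torch
    import thunder

    # Turn the saved-for-backward recomputation off entirely.
    jitted = thunder.jit(torch.nn.Linear(128, 128), enable_saved_for_backward_recomputation=False)

    # Or keep it on but let a policy decide which saved tensors may be recomputed.
    def recompute_nothing(candidates):
        # Returning an empty set means every candidate stays saved instead of being recomputed.
        return set()

    jitted_with_policy = thunder.jit(
        torch.nn.Linear(128, 128),
        enable_saved_for_backward_recomputation=True,
        recomputation_policy=recompute_nothing,
    )

    out = jitted(torch.randn(4, 128, requires_grad=True))
    out.sum().backward()
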