update checkpointing support for jit #1560

Open, wants to merge 22 commits into base: main
2 changes: 1 addition & 1 deletion notebooks/liger_kernel.ipynb
@@ -367,10 +367,10 @@
"\n",
"jm = thunder.jit(m, executors=(liger_ex,), transforms=(MergeRopeTransform(),))\n",
"res = jm(inp, inp_pos)\n",
"ref = m(inp, inp_pos)\n",
"\n",
"go = torch.randn_like(res)\n",
"(grad_res,) = torch.autograd.grad(res, jm.get_parameter(\"transformer.wte.weight\"), go)\n",
"ref = m(inp, inp_pos)\n",
"(grad_ref,) = torch.autograd.grad(ref, m.get_parameter(\"transformer.wte.weight\"), go)\n",
"\n",
"assert_close(res, ref)\n",
7 changes: 5 additions & 2 deletions thunder/__init__.py
@@ -72,7 +72,7 @@
)
from thunder.core.interpreter import print_interpreter_log, print_to_log
from thunder.core.jit_ext import thunder_general_jit
from thunder.executors.torch_autograd import split_forward_backward, ThunderFunction
from thunder.executors.torch_autograd import split_forward_backward, ThunderFunction1, ThunderFunction2

# NOTE This import is intentionally pytorch so that thunder.torch doesn't import this
import torch as pytorch
@@ -751,14 +751,17 @@ def maybe_connect_to_autograd(cache_entry, result):
# resulting tensors to PyTorch's Autograd graph using the
# ThunderFunction (which is a torch.autograd.Function subclass)
data_for_autograd, (saved_tensors, saved_other) = result
ThunderFunction.apply(
side_channel = {}
dummy_res = ThunderFunction1.apply(
cache_entry.return_none_instead_of_grads,
cache_entry.backward_fn,
side_channel,
saved_tensors,
saved_other,
data_for_autograd["flat_output"],
*data_for_autograd["flat_args"],
)
ThunderFunction2.apply(dummy_res, side_channel)
result = data_for_autograd["output"]

return result
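The change above splits the old `ThunderFunction` into `ThunderFunction1` and `ThunderFunction2`, which communicate through a `side_channel` dict. Below is a minimal, hypothetical sketch of that two-`autograd.Function` pattern (names and the toy computation are illustrative, not the actual Thunder implementation): the second Function sits downstream in the graph, so its backward runs first and can hand the saved tensors back to the first Function through the shared dict.

```python
# Hypothetical sketch of two torch.autograd.Functions coordinated through a shared
# "side channel" dict; names and the toy computation are illustrative only.
import torch


class _Stage1(torch.autograd.Function):
    @staticmethod
    def forward(ctx, side_channel, x):
        ctx.side_channel = side_channel
        side_channel["to_save"] = (x,)  # hand tensors to the second stage for saving
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        # _Stage2.backward has already run and put the saved tensors back in the dict
        (x,) = ctx.side_channel.pop("saved")
        return None, 2 * x * grad_out


class _Stage2(torch.autograd.Function):
    @staticmethod
    def forward(ctx, dummy, side_channel):
        ctx.side_channel = side_channel
        ctx.save_for_backward(*side_channel.pop("to_save"))
        return dummy

    @staticmethod
    def backward(ctx, grad_out):
        # runs before _Stage1.backward; return the saved tensors via the side channel
        ctx.side_channel["saved"] = ctx.saved_tensors
        return grad_out, None


x = torch.randn(3, requires_grad=True)
side_channel = {}
out = _Stage1.apply(side_channel, x)
out = _Stage2.apply(out, side_channel)
out.sum().backward()
torch.testing.assert_close(x.grad, 2 * x.detach())
```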
17 changes: 10 additions & 7 deletions thunder/benchmarks/benchmark_litgpt.py
@@ -537,19 +537,15 @@ def setup_distributed(self, model):
return model

def setup_activation_checkpointing(self):
if "thunder" in self.compile and "dynamo" not in self.compile:
# checkpointing is an option to thunder.jit
return

if any(isinstance(mod, CheckpointWrapper) for mod in self.model.modules()):
warnings.warn(
"FSDP checkpointing is configured, but the model already contains checkpointed layers."
" Checkpointing will be ignored."
)
return

check_fn = lambda submodule: isinstance(submodule, Block)
apply_activation_checkpointing(self.model, checkpoint_wrapper_fn=checkpoint_wrapper, check_fn=check_fn)
print(self.model)

# TODO(crcrpar): Think of applying `torch.compile` or `thunder.jit` per block/module
# like https://github.com/pytorch/torchtitan/blob/cfc0f4e/torchtitan/parallelisms/parallelize_llama.py#L275-L284
@@ -886,12 +882,19 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None
print(f"##########\n#Graph{gid}-ThunderFn{subgid} last backward trace\n##########")
print(thunder.last_backward_traces(thunder_fn)[-1])
else:
from thunder.examine.memory_calculation import get_alloc_memory
from thunder.executors.passes import del_last_used

for i, f_traces in enumerate(fwd_traces, start=1):
print(f"##########\n#{i}-th ThunderModule\n##########")
print(f_traces[-1])
print(f_traces)
for i, b_traces in enumerate(bwd_traces, start=1):
print(f"##########\n#{i}-th ThunderModule\n##########")
print(b_traces[-1])
for tr in b_traces:
dltr = del_last_used(tr)
tr_peak_memory, _ = get_alloc_memory(dltr)
print(f"#the following trace uses ~{tr_peak_memory/(2**30):.2f}GB memory")
print(tr)

if global_rank in [0, None]:
if return_metrics_as_json:
58 changes: 48 additions & 10 deletions thunder/core/jit_ext.py
@@ -895,11 +895,18 @@ def _general_jit_torch_finfo_lookaside(dtype: thunder.dtypes.dtype):
return res


ProxyTag.register_tag("RECOMPUTE_IN_BACKWARD")


@register_general_jit_lookaside(torch.utils.checkpoint.checkpoint)
def _general_jit_torch_checkpoint_lookaside(
function: Callable,
*args,
**kwargs: Any,
context_fn: None | Callable[..., Any] = None,
debug: None | bool = None,
determinism_check: None | str = None,
preserve_rng_state: None | bool = None,
use_reentrant: bool = False,
):
"""
This function does preprocessing of the `function` argument before
@@ -917,17 +924,48 @@ def _general_jit_torch_checkpoint_lookaside(
The result of calling `thunder.torch.checkpoint` with the preprocessed
`function` and its arguments.
"""
from thunder.torch import checkpoint

# It should be possible to call the general_thunder_jit here to handle the
# conversion from torch to thunder but it doesn't work now
# See https://github.com/Lightning-AI/lightning-thunder/issues/1126
# TODO: Convert the function to a Thunder function
def thunder_function(*args, **kwargs):
return unwrap(function)(*args, **kwargs)
if unwrap(use_reentrant):
return do_raise(
"torch.checkpoint: use_reentrant=True is not supported in Thunder",
)
# NOTE: Thunder currently ignores the context_fn, debug, determinism_check, preserve_rng_state arguments
# Warn if any of these arguments is passed
if unwrap(context_fn) is not None:
warnings.warn("torch.checkpoint: context_fn is not supported in Thunder and will be ignored")
if unwrap(debug) is not None:
warnings.warn("torch.checkpoint: debug is not supported in Thunder and will be ignored")
if unwrap(determinism_check) is not None:
warnings.warn("torch.checkpoint: determinism_check is not supported in Thunder and will be ignored")
if unwrap(preserve_rng_state) is not None:
warnings.warn("torch.checkpoint: preserve_rng_state is not supported in Thunder and will be ignored")

jit_ctx: JitCtx = get_jit_ctx()
jit_ctx.computation_trace.push_scope([])

input_output_proxy_names = set()

def add_input_output_proxy_name(p):
if isinstance(p, Proxy):
input_output_proxy_names.add(p.name)

tree_map(add_input_output_proxy_name, [unwrap(a) for a in args])

wrapped_thunder_function = wrap_const(thunder_function)
return interpreter_needs_wrap(checkpoint)(wrapped_thunder_function, *args, **kwargs)
res = _interpret_call(function, *args)
if res is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
return res

tree_map(add_input_output_proxy_name, unwrap(res))

new_bsyms = jit_ctx.computation_trace.pop_scope()
jit_ctx.computation_trace.bound_symbols.extend(new_bsyms)

for bsym in new_bsyms:
for o in bsym.flat_proxy_outs:
if o.name not in input_output_proxy_names:
o.tags.add(ProxyTag.RECOMPUTE_IN_BACKWARD)

return res


# Adds proxy methods
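With the lookaside above, a `torch.utils.checkpoint.checkpoint` call inside a jitted module is interpreted inline and the intermediates it produces are tagged `RECOMPUTE_IN_BACKWARD`; only `use_reentrant=False` is accepted. A minimal usage sketch (the module and tensor shapes are made up):

```python
# Hypothetical usage sketch of activation checkpointing under thunder.jit;
# the Block module and shapes are illustrative only.
import torch
import thunder
from torch.utils.checkpoint import checkpoint


class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(64, 64)
        self.fc2 = torch.nn.Linear(64, 64)

    def _ff(self, x):
        return self.fc2(torch.nn.functional.gelu(self.fc1(x)))

    def forward(self, x):
        # intermediates produced inside the checkpointed region are tagged
        # RECOMPUTE_IN_BACKWARD rather than saved for backward
        return checkpoint(self._ff, x, use_reentrant=False)


jm = thunder.jit(Block())
out = jm(torch.randn(8, 64, requires_grad=True))
out.sum().backward()
```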
1 change: 1 addition & 0 deletions thunder/core/prims.py
@@ -2795,6 +2795,7 @@ def _get_and_update_rng_state_meta(
PrimIDs.GET_AND_UPDATE_RNG_STATE,
"get_and_update_rng_state",
meta=_get_and_update_rng_state_meta,
tags=(OpTags.RANDOM_OP,),
)


4 changes: 3 additions & 1 deletion thunder/core/rematerialization.py
@@ -192,7 +192,9 @@ def apply_rematerialization_for_consumer(
_, leaves = bsym_list_to_dag(list(new_subsymbols))
new_subsymbols = toposort_bsym_dag(leaves, TOPOSORT_ORDER.BOTTOM_UP)
proxy_order = order_proxies(new_subsymbols)
new_consumer_args = tuple(sorted(new_consumer_args, key=lambda x: proxy_order[x.name]))
new_consumer_args = tuple(
sorted((a for a in new_consumer_args if a.name in proxy_order), key=lambda x: proxy_order[x.name])
)
new_consumer = replace(consumer, args=new_consumer_args, subsymbols=new_subsymbols)
return new_consumer

9 changes: 8 additions & 1 deletion thunder/core/symbol.py
@@ -16,7 +16,7 @@
import thunder.core.baseutils as baseutils
import thunder.core.codeutils as codeutils
from thunder.core.codeutils import Printable, Positions
from thunder.core.baseutils import BoundSymbolInterface, ProxyInterface
from thunder.core.baseutils import BoundSymbolInterface, ProxyInterface, TagBase
from thunder.core.utils import FrozenDict, make_hashable
from thunder.core.pytree import tree_flatten_with_dataclass, tree_unflatten, tree_map
import thunder.core.dtypes as dtypes
@@ -351,6 +351,10 @@ def tag_tensorproxy_output_as_detached(proxy):
return result


class BoundSymbolTag(TagBase):
pass


# A symbol, arguments (and kwarguments), output, and sub-symbols
# args is a sequence of the arguments
# kwargs is a dict of the kwargs
@@ -377,6 +381,8 @@ class BoundSymbol(BoundSymbolInterface):
source_filename: str | None = None
source_positions: Positions | None = None

bsym_tags: set[BoundSymbolTag] = field(default_factory=set)

_call_ctx: None | dict[str, Any] = None

_import_ctx: dict = field(default_factory=dict)
@@ -412,6 +418,7 @@ def from_bsym(self, **kwargs) -> BoundSymbol:
"_import_ctx": self._import_ctx,
"_object_ctx": self._object_ctx,
"_executor": self._executor,
"bsym_tags": self.bsym_tags.copy(),
}

self_kwargs.update(kwargs)
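The new `bsym_tags` field is copied by `from_bsym`, so tags survive the clones made during trace rewrites. A small sketch, assuming `TagBase` exposes the same `register_tag` mechanism that `ProxyTag` uses in the jit_ext.py hunk above (the tag name and helper are made up):

```python
# Hypothetical sketch; assumes BoundSymbolTag.register_tag works like ProxyTag.register_tag.
from thunder.core.symbol import BoundSymbolTag

BoundSymbolTag.register_tag("MY_PASS_MARKER")  # made-up tag name


def mark_and_clone(bsym):
    # from_bsym copies bsym_tags, so the marker survives the clone
    bsym.bsym_tags.add(BoundSymbolTag.MY_PASS_MARKER)
    return bsym.from_bsym()
```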
16 changes: 14 additions & 2 deletions thunder/core/trace.py
@@ -84,8 +84,8 @@ def __init__(self, fn: None | Callable = None, *, prologue: TraceCtx | None = No
self.args = None
self.kwargs = {}

self.bound_symbols: list[BoundSymbolInterface] = []
self.scopes = [self.bound_symbols]
self._bound_symbols: list[BoundSymbolInterface] = []
self.scopes = [self._bound_symbols]

self.name_ctr = 0
self.obj_name_ctr = 0
@@ -130,6 +130,16 @@ def __init__(self, fn: None | Callable = None, *, prologue: TraceCtx | None = No
# We only want the forward function to be called with ctx manager.
self._include_te_fp8_autocast = False

@property
def bound_symbols(self) -> list[BoundSymbolInterface]:
return self._bound_symbols

@bound_symbols.setter
def bound_symbols(self, bsyms: list[BoundSymbolInterface]):
assert self.scopes[0] is self._bound_symbols
self._bound_symbols = bsyms
self.scopes[0] = bsyms

@property
def tags(self):
return self._tags
@@ -247,8 +257,10 @@ def add_bound_symbol(self, bsym: BoundSymbolInterface) -> None:

def push_scope(self, scope: list) -> None:
self.scopes.append(scope)
assert self.scopes[0] is self.bound_symbols

def pop_scope(self) -> list:
assert self.scopes[0] is self.bound_symbols
return self.scopes.pop()

def peek_scope(self) -> list | None:
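The new property/setter keeps `scopes[0]` and `bound_symbols` bound to the same list, which `push_scope`/`pop_scope` now assert. A minimal sketch of the invariant (constructing a bare `TraceCtx` purely for illustration):

```python
# Minimal sketch of the invariant guarded by the new setter and asserts.
from thunder.core.trace import TraceCtx

trace = TraceCtx()
trace.bound_symbols = []                       # the setter rebinds scopes[0] as well
assert trace.scopes[0] is trace.bound_symbols  # holds after reassignment
```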
5 changes: 4 additions & 1 deletion thunder/core/trace_interpreter.py
@@ -6,7 +6,7 @@
from thunder.core.trace import VariableInterface, from_trace, tracectx
from thunder.core.baseutils import ProxyInterface, TensorProxyInterface
from thunder.core.utils import safe_map_flat, sequencify
from thunder.core.proxies import variableify
from thunder.core.proxies import variableify, ProxyTag
from thunder.core.transform_common import VJPDual


@@ -183,6 +183,9 @@ def do_swap(v):

for new_bsym in new_bsyms:
# TODO: what to do with bsym header? Maybe have a combined from_bsym_swap_proxies and from_bsym?
for o in new_bsym.flat_proxy_outs:
if variableify(o) not in swap_map:
o.tags.add(ProxyTag.RECOMPUTE_IN_BACKWARD)
new_trace.bound_symbols.append(
new_bsym.from_bsym_swap_proxies(swap_map).from_bsym(
source_filename=bsym.source_filename, source_positions=bsym.source_positions
Expand Down