Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Draft checkpoint interpret call #1275

Closed
wants to merge 10 commits into from
226 changes: 226 additions & 0 deletions thunder/core/jit_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
)

import torch
import torch.utils.checkpoint
from thunder.core.proxies import (
DistParallelType,
proxy,
Expand Down Expand Up @@ -617,6 +618,231 @@ def _general_jit_torch_finfo_lookaside(dtype: thunder.dtypes.dtype):
return res


from torch.utils.checkpoint import noop_context_fn


@register_general_jit_lookaside(torch.utils.checkpoint.checkpoint)
@register_general_jit_lookaside(torch.ops.higher_order.tag_activation_checkpoint)
def _general_jit_torch_checkpoint_lookaside(
function: Callable,
*args,
# use_reentrant=None,
# context_fn=noop_context_fn,
# determinism_check="default",
# debug=False,
**kwargs: Any,
):
"""
This function does preprocessing of the `function` argument before
dispatching the call to `thunder.torch.checkpoint`. This is necessary
because the `function` is potentially calling into PyTorch functions that
are not yet translated to Thunder. `thunder.torch.checkpoint` is a Thunder
function that can handle only Thunder functions as input.

Args:
function: The function to be checkpointed.
args: Arguments to the function.
kwargs: Keyword arguments to the function.

Returns:
The result of calling `thunder.torch.checkpoint` with the preprocessed
`function` and its arguments.
"""
# from thunder.torch import checkpoint
from thunder.core.baseutils import check, sequencify
from thunder.core.transforms import augmented_forward_impls, backward_impls, VJPDual

jit_ctx: JitCtx = get_jit_ctx()

# Construct computation trace(trace_of_checkpoint), checkpoint_fwd_sym
jit_ctx.computation_trace.push_scope([])
func = unwrap(function)
result = _interpret_call(func, *args, **kwargs)

if result is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
return result

bsyms: list[BoundSymbol] = jit_ctx.computation_trace.pop_scope()
unwrapped_result = unwrap(result)
trace_of_checkpoint = TraceCtx()
for bsym in bsyms:
trace_of_checkpoint.add_bound_symbol(bsym)
with tracectx(trace_of_checkpoint):
prims.python_return(unwrapped_result)

unwrapped_args = tree_map(lambda a: unwrap(a), args)

si = SigInfo("activation_checkpoint")
si.args.append(("function", None))
for a in unwrapped_args:
if isinstance(a, Proxy):
si.args.append((a.name, None))
else:
pa = proxy(a)
si.args.append((pa.name, None))
trace_of_checkpoint._siginfo = si
trace_of_checkpoint.args = (func, *unwrapped_args)

@wraps(trace_of_checkpoint.python_callable())
def core_of_forward(f, *args, **kwargs):
return thunder.core.trace_interpreter.interpret_trace(trace_of_checkpoint, f, *args, **kwargs)
Comment on lines +686 to +688
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this core_of_forward (but without f argument) be sent as the function-to-be-checkpointed to thunder.torch.checkpoint function from #1127? Then the rest of the custom registration could probably be dropped.


def bind_postprocess(bsym):
bsym._call_ctx = {}

checkpoint_fwd_sym = Symbol(
name="activation_checkpoint",
id="activation_checkpoint",
meta=core_of_forward,
_bind_postprocess=bind_postprocess,
)
# checkpoint_fwd_sym = jit_ctx.ad_hoc_executor.register_operator(
# "activation_checkpoint",
# like=core_of_forward,
# bind_postprocess=bind_postprocess,
# )
unwrapped_forward_result = checkpoint_fwd_sym(func, *unwrapped_args)
# return value
forward_result = wrap(
unwrapped_forward_result,
provenance=ProvenanceRecord(PseudoInst.LOOKASIDE, inputs=[function.provenance, result.provenance]),
)
# jit_ctx.ad_hoc_executor.register_implementation(checkpoint_fwd_sym, execution_transform=core_of_forward)
thunder.executors.torchex._register_implementation(
checkpoint_fwd_sym, core_of_forward, checker=thunder.executors.torchex._always_executable
)

# construct checkpoint augmented forward(trace_of_augmented_fwd), augmented forward meta function
augmented_bsym_output: tuple[tuple[TensorProxy, ...], tuple[TensorProxy, ...]] = (
tuple(sequencify(unwrapped_result)),
((func, *sequencify(unwrapped_args)), {}),
)
trace_of_augmented_fwd = TraceCtx()
for bsym in bsyms:
trace_of_augmented_fwd.add_bound_symbol(bsym)
with tracectx(trace_of_augmented_fwd):
prims.python_return(augmented_bsym_output)
si = SigInfo(checkpoint_fwd_sym.name)
si.args.append(("function", None))
for a in unwrapped_args:
if isinstance(a, Proxy):
si.args.append((a.name, None))
else:
pa = proxy(a)
si.args.append((pa.name, None))
trace_of_augmented_fwd._siginfo = si
# TODO: support kwargs
trace_of_augmented_fwd.args = (func, *unwrapped_args)

@wraps(trace_of_augmented_fwd.python_callable())
def core_of_augmented_forward(f, *args, **kwargs):
return thunder.core.trace_interpreter.interpret_trace(trace_of_augmented_fwd, f, *args, **kwargs)

@wraps(core_of_augmented_forward)
def augmented_custom_forward_rule(f, *args, **kwargs):
primal, residulas = core_of_augmented_forward(f, *args, **kwargs)
# import pdb;pdb.set_trace()
return VJPDual(primal=primal, residuals=residulas)

augmented_forward_impls[checkpoint_fwd_sym.name] = augmented_custom_forward_rule

# construct backward, has problem
from thunder.core.transforms import vjp

def checkpoint_backward(
function,
args,
kwargs,
*grad_outputs,
):
result, grad = vjp(function)(args, grad_outputs, **kwargs)
return grad # result

grads = tree_map(
lambda a: a.replace_name(f"grad_{a.name}"),
sequencify(unwrapped_forward_result),
)
trace_of_backward = TraceCtx()
bwd_si = SigInfo(f"{checkpoint_fwd_sym.name}_backward")
bwd_si.args.append(("function", None))
for a in unwrapped_args + grads:
if isinstance(a, Proxy):
bwd_si.args.append((a.name, None))
else:
pa = proxy(a)
bwd_si.args.append((pa.name, None))
trace_of_backward._siginfo = bwd_si
trace_of_backward.args = (func, *(unwrapped_args + grads))

jit_ctx.computation_trace.push_scope([])
wrapped_grads = tree_map(lambda g: wrap(g, provenance=result.provenance), grads)

# pr1 = ProvenanceRecord(PseudoInst.BUILD_TUPLE, inputs=[v.provenance for v in args]) # other inst?
# pr2 = ProvenanceRecord(PseudoInst.BUILD_TUPLE, inputs=[v.provenance for v in wrapped_grads])
# res = _interpret_call(func, *args, **kwargs)
tmp = vjp(checkpoint_fwd_sym)
import pdb

pdb.set_trace()
# TODO How to call vjp on forward symbol, currently:
# TypeError: decomposed_fn_backward_rule(decomposed_fn, args, kwargs, saved_for_backward, *grads) doesn't match the signature of
# result = backward(*residuals, *cotangents) when activation_checkpoint augforward symbol is registered but backward is not and use decomposition
# residules are saved_for_backward here??
checkpoint_backward_result = tmp((func, *unwrapped_args), grads, **kwargs)

if checkpoint_backward_result is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
return checkpoint_backward_result

checkpoint_bwd_sym: list[BoundSymbol] = jit_ctx.computation_trace.pop_scope()

for bsym in checkpoint_bwd_sym:
trace_of_backward.add_bound_symbol(bsym)
with tracectx(trace_of_backward):
prims.python_return.bind(*unwrap(checkpoint_backward_result)[1], output=())

# @wraps(trace_of_backward.python_callable())
# def bwd_trace_callable_interface(f, *args, **kwargs):
# return thunder.core.trace_interpreter.interpret_trace(trace_of_backward, f, *args, **kwargs)

# bwd_si = SigInfo("backward_impl")
# bwd_si.args.append(("function", None))
# for a in unwrapped_args + grads:
# if isinstance(a, Proxy):
# bwd_si.args.append((a.name, None))
# else:
# pa = proxy(a)
# bwd_si.args.append((pa.name, None))
# bwd_trace_impl = TraceCtx()
# for bsym in checkpoint_bwd_sym:
# bwd_trace_impl.add_bound_symbol(bsym)
# bwd_trace_impl.add_bound_symbol(prims.python_return.bind(*unwrap(checkpoint_backward_result)[1], output=()))
# bwd_trace_impl._siginfo = bwd_si
# bwd_trace_impl.args = tuple(func + unwrapped_args + grads)

# @wraps(bwd_trace_impl.python_callable())
# def bwd_impl_callable(f, *args, **kwargs):
# return thunder.core.trace_interpreter.interpret_trace(bwd_trace_impl, f, *args, **kwargs)

# @wraps(bwd_trace_callable_interface)
# def backward_impl(f, *args, **kwargs):
# # check(not kwargs, lambda: f"{kwargs} expected to be empty")
# # new_args = ctx_proxy.saved_consts + args
# return bwd_impl_callable(f, *args, **kwargs)

# backward_impls[checkpoint_fwd_sym.name] = backward_impl
return forward_result

# It should be possible to call the general_thunder_jit here to handle the
# conversion from torch to thunder but it doesn't work now
# See https://github.com/Lightning-AI/lightning-thunder/issues/1126
# TODO: Convert the function to a Thunder function
# def thunder_function(*args, **kwargs):
# return unwrap(function)(*args, **kwargs)

# wrapped_thunder_function = wrap_const(thunder_function)
# return interpreter_needs_wrap(checkpoint)(wrapped_thunder_function, *args, **kwargs)


# Adds proxy methods
# NOTE These methods map to themselves, which prevents the interpreter from looking into them
# This is OK because these methods are written in a tracing-safe manner, and trying to
Expand Down
2 changes: 2 additions & 0 deletions thunder/core/pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import thunder.core.dtypes as dtypes
import thunder.core.devices as devices
from thunder.core.baseutils import ProxyInterface
from types import FunctionType

OPTREE_NAMESPACE = "thunder"

Expand All @@ -24,6 +25,7 @@ def tree_flatten(args, namespace=""):
if (
type(args)
not in {
FunctionType,
dict,
list,
str,
Expand Down
43 changes: 43 additions & 0 deletions thunder/tests/test_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -1673,6 +1673,49 @@ def func(a, b):
get_saved_for_backward_tensors(execution_trace)


def test_torch_checkpoint():
import torch.utils.checkpoint
import torch._higher_order_ops.wrap
from thunder.dynamo import ThunderCompiler

def fn_to_checkpoint(x):
# return x.sin()#.cos().exp()
return torch.sin(x)

checkpoint_fns = (
# torch.utils.checkpoint.checkpoint,
# thunder.torch.checkpoint,
partial(torch.utils.checkpoint.checkpoint, use_reentrant=False),
torch.ops.higher_order.tag_activation_checkpoint,
)

for checkpoint_fn in checkpoint_fns:

def f(x):
return checkpoint_fn(fn_to_checkpoint, x, use_reentrant=False)

x = make_tensor((2, 2), device="cpu", dtype=torch.float32, requires_grad=True)
# backend = ThunderCompiler()
# jf = torch.compile(backend=backend)(f)
jf = thunder.jit(f)
out = jf(x)
print(thunder.last_traces(jf)[0])
print(thunder.last_backward_traces(jf)[0])

# With activation checkpointing, we are saving only the original input.
# The intermediate values are recomputed during backward pass.
assert len(out.grad_fn.saved_tensors) == 1
assert out.grad_fn.saved_tensors[0] is x

g = torch.ones_like(out)
out.backward(g)

x_ref = x.detach().requires_grad_()
out_ref = fn_to_checkpoint(x_ref)
out_ref.backward(g)
torch.testing.assert_close(x.grad, x_ref.grad)


def test_inconsistent_output_length_grad_transform():
from thunder.extend import OperatorExecutor
from thunder.core.proxies import AnyProxy, TensorProxy
Expand Down
69 changes: 68 additions & 1 deletion thunder/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
)
from thunder.core.pytree import tree_map, tree_flatten, tree_unflatten
from thunder.core.symbol import Symbol
from thunder.core.transforms import register_grad
from thunder.core.transforms import register_grad, register_augmented_forward, register_backward
from thunder.core.prims import get_grad, put_grad
from thunder.core.baseutils import run_once
import thunder
Expand All @@ -56,6 +56,8 @@

# NOTE torch is a requirement
import torch
import torch.utils.checkpoint
import torch._higher_order_ops.wrap

import warnings

Expand Down Expand Up @@ -5152,6 +5154,71 @@ def _unwrap_if_dead(tensor):
register_function(torch._C._functorch.unwrap_if_dead, _unwrap_if_dead)


# @torchsymbol(
# torch.utils.checkpoint.checkpoint,
# torch.ops.higher_order.tag_activation_checkpoint,
# id="activation_checkpoint",
# )
# def checkpoint(
# function: Callable[..., TensorLike],
# *args: TensorLike,
# context_fn: None | Callable[..., Any] = None,
# debug: None | bool = None,
# determinism_check: None | str = None,
# preserve_rng_state: None | bool = None,
# use_reentrant: bool = False,
# **kwargs: Any,
# ) -> TensorLike:
# utils.check(
# not use_reentrant,
# lambda: "torch.checkpoint: use_reentrant=True is not supported in Thunder",
# )
# # NOTE: Thunder currently ignores the context_fn, debug, determinism_check, preserve_rng_state arguments
# # Let's raise a warning if any of these arguments are passed
# if context_fn is not None:
# warnings.warn("torch.checkpoint: context_fn is not supported in Thunder and will be ignored")
# if debug is not None:
# warnings.warn("torch.checkpoint: debug is not supported in Thunder and will be ignored")
# if determinism_check is not None:
# warnings.warn("torch.checkpoint: determinism_check is not supported in Thunder and will be ignored")
# if preserve_rng_state is not None:
# warnings.warn("torch.checkpoint: preserve_rng_state is not supported in Thunder and will be ignored")
# return function(*args, **kwargs)


# @register_augmented_forward(
# "activation_checkpoint",
# )
# def _augmented_forward_checkpoint(
# function: Callable[..., TensorLike],
# *args: TensorLike,
# # context_fn: None | Callable[..., Any] = None,
# # debug: None | bool = None,
# # determinism_check: None | str = None,
# # preserve_rng_state: None | bool = None,
# # use_reentrant: bool = False,
# **kwargs: Any,
# ) -> TensorLike:
# result = function(*args, **kwargs)
# saved_for_backward = (function, args, kwargs)
# return result, saved_for_backward


# @register_backward(
# "activation_checkpoint",
# )
# def _backward_checkpoint(
# function,
# args,
# kwargs,
# *grad_outputs,
# ) -> tuple[None | TensorLike, ...]:
# from thunder.core.transforms import vjp

# result, grad = vjp(function)(args, grad_outputs, **kwargs)
# return grad #result


#
# Distributed operations
#
Expand Down
Loading