test updates
t-vi committed Dec 17, 2024
1 parent 49b6f23 commit 2a08222
Showing 2 changed files with 51 additions and 2 deletions.
5 changes: 5 additions & 0 deletions thunder/core/transforms.py
@@ -3157,6 +3157,8 @@ def backward_fn(saved_for_backward, cotangents):
enable_saved_for_backward_recomputation: None | bool = get_compile_option(
"enable_saved_for_backward_recomputation", "Enable save for backward tensors recomputation."
)
if enable_saved_for_backward_recomputation is None:
enable_saved_for_backward_recomputation = True
if enable_saved_for_backward_recomputation:
forward_trace, backward_trace = recompute_saved_for_backward(forward_trace, backward_trace)

@@ -3195,6 +3197,9 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Tr
if thunder.core.proxies.ProxyTag.RECOMPUTE_IN_BACKWARD in thunder.core.proxies.unvariableify(p).tags
}

if not rematerializable:
return fwd_trace, bwd_trace

producers = find_producer_symbols(
fwd_trace,
tuple(unvariableify(i) for i in rematerializable),
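
With this change, recomputation of saved-for-backward tensors defaults to on when the compile option is left unset, and recompute_saved_for_backward returns the traces unchanged when nothing is tagged for recomputation. A minimal sketch of opting back out, assuming compile options are passed as keyword arguments to thunder.jit and read back via get_compile_option (the model below is only illustrative):

import torch
import thunder

model = torch.nn.Sequential(torch.nn.Linear(16, 2048), torch.nn.ReLU(), torch.nn.Linear(2048, 16))

# Option left unset: the transform now treats it as True and recomputes
# saved-for-backward tensors in the backward trace.
jm_default = thunder.jit(model)

# Explicit opt-out (assumes the option is a plain keyword argument to thunder.jit).
jm_no_recompute = thunder.jit(model, enable_saved_for_backward_recomputation=False)
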
48 changes: 46 additions & 2 deletions thunder/tests/test_grad.py
@@ -1753,9 +1753,10 @@ def f(x, y):
# With activation checkpointing, we are saving only the original input.
# The intermediate values are recomputed during backward pass.
assert len(out.grad_fn.saved_tensors) == 2

# We detach the saved tensors (which returns a new Python tensor backed by the same storage)
assert out.grad_fn.saved_tensors[0].data_ptr() == x.data_ptr()
assert out.grad_fn.saved_tensors[1].data_ptr() == y.data_ptr()
# The order of the saved tensors can be non-deterministic, so compare the data pointers as a set.
assert {t.data_ptr() for t in out.grad_fn.saved_tensors} == {x.data_ptr(), y.data_ptr()}

g = torch.ones_like(out)
out.backward(g)
@@ -1768,6 +1769,49 @@ def f(x, y):
torch.testing.assert_close(y.grad, y_ref.grad)


@requiresCUDA
def test_checkpoint_max_memory():
import torch.utils.checkpoint

class Checkpoint(torch.nn.Module):
def __init__(self, module):
super().__init__()
self.module = module

def forward(self, *args):
return torch.utils.checkpoint.checkpoint(self.module, *args, use_reentrant=False)

with torch.device("cuda:0"):
m = torch.nn.Sequential(
torch.nn.Linear(1024, 16),
torch.nn.ReLU(),
*[
Checkpoint(
torch.nn.Sequential(
torch.nn.Linear(16, 2048),
torch.nn.Linear(2048, 16),
torch.nn.ReLU(),
)
)
for _ in range(10)
],
torch.nn.Linear(16, 1024),
)
inps = torch.randn(512, 1024, requires_grad=True)

jm = thunder.jit(m, executors=()) # no rematerialization
mem_base = torch.cuda.memory_allocated()
torch.cuda.reset_accumulated_memory_stats()
res = jm(inps)
res.sum().backward()
mem_max = torch.cuda.max_memory_allocated()
# The rematerialization pass used to move (seemingly all of) the recomputation
# to the front of the backward pass, making the peak memory about 46 MB.
# With checkpointing as written in the model and recomputation where the
# values are used, we get about 12 MB, so the threshold is set at 16 MB
# (see the sizing note after the diff).
assert mem_max - mem_base < 16 * 2**20


def test_inconsistent_output_length_grad_transform():
from thunder.extend import OperatorExecutor
from thunder.core.proxies import AnyProxy, TensorProxy
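
A back-of-the-envelope sizing note for the thresholds in test_checkpoint_max_memory (an editorial sketch, not part of the commit): each checkpointed block materializes a (512, 2048) float32 activation, i.e. 512 * 2048 * 4 B = 4 MiB. Recomputing all ten blocks at the front of the backward pass keeps roughly 10 * 4 MiB = 40 MiB of such activations live at once, which lines up with the ~46 MB peak, while recomputing each block only where its values are used keeps roughly one block's buffers live at a time, which lines up with the ~12 MB peak and the 16 MB threshold.

# Editorial sketch of the arithmetic behind the thresholds (assumed tensor sizes, not measurements).
batch, hidden, blocks = 512, 2048, 10
bytes_per_float32 = 4

per_block_activation = batch * hidden * bytes_per_float32  # 4 MiB recomputed per checkpointed block
all_up_front = blocks * per_block_activation               # ~40 MiB live at once -> ~46 MB observed peak
at_point_of_use = per_block_activation                     # ~4 MiB live at a time -> ~12 MB observed peak

print(per_block_activation / 2**20, all_up_front / 2**20, at_point_of_use / 2**20)  # 4.0 40.0 4.0
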
