mv test to test_dynamo.py
kiya00 committed Oct 11, 2024
1 parent 209a3f2 commit 1d1ab80
Showing 2 changed files with 40 additions and 40 deletions.
40 changes: 40 additions & 0 deletions thunder/tests/test_dynamo.py
@@ -435,3 +435,43 @@ def test_thundercompiler_optim_step(executor, device, dtype, optim):
tuple(ref_model.parameters()),
msg=lambda s: f"{i+1}-iter {s}",
)


def test_torch_checkpoint_dynamo():
import torch.utils.checkpoint as checkpoint
import torch.nn as nn
from thunder.dynamo import ThunderCompiler

class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
            self.layer1 = nn.Linear(10, 20)
self.layer2 = nn.Linear(20, 20)
            self.layer3 = nn.ReLU()

def forward(self, x):
# Use checkpointing for layers where you want to save memory
x = torch.sin(x)
            x = checkpoint.checkpoint(self.layer1, x) # Checkpoint layer1
x = checkpoint.checkpoint(self.layer2, x) # Checkpoint layer2
x = self.layer3(x) # No checkpoint for layer3
return x

# Input tensor
x = torch.randn(5, 10).requires_grad_()
model = SimpleModel().train()
backend = ThunderCompiler()
jf = torch.compile(backend=backend)(model)
    # jf = thunder.jit(model)
out = jf(x)
# print(thunder.last_traces(jf)[0])
# print(thunder.last_traces(backend.subgraph_infos[0].thunder_compiled_fns[0])[0])
# print(thunder.last_backward_traces(backend.subgraph_infos[0].thunder_compiled_fns[0])[0])

g = torch.ones_like(out)
out.backward(g)

x_ref = x.detach().requires_grad_()
out_ref = model(x_ref)
out_ref.backward(g)
torch.testing.assert_close(x.grad, x_ref.grad)
40 changes: 0 additions & 40 deletions thunder/tests/test_grad.py
@@ -1736,46 +1736,6 @@ def f(x):
torch.testing.assert_close(x.grad, x_ref.grad)


def test_torch_checkpoint_dynamo():
import torch.utils.checkpoint as checkpoint
import torch.nn as nn
from thunder.dynamo import ThunderCompiler

class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
            self.layer1 = nn.Linear(10, 20)
self.layer2 = nn.Linear(20, 20)
            self.layer3 = nn.ReLU()

def forward(self, x):
# Use checkpointing for layers where you want to save memory
x = torch.sin(x)
            x = checkpoint.checkpoint(self.layer1, x) # Checkpoint layer1
x = checkpoint.checkpoint(self.layer2, x) # Checkpoint layer2
x = self.layer3(x) # No checkpoint for layer3
return x

# Input tensor
x = torch.randn(5, 10).requires_grad_()
model = SimpleModel().train()
backend = ThunderCompiler()
jf = torch.compile(backend=backend)(model)
    # jf = thunder.jit(model)
out = jf(x)
# print(thunder.last_traces(jf)[0])
print(thunder.last_traces(backend.subgraph_infos[0].thunder_compiled_fns[0])[0])
print(thunder.last_backward_traces(backend.subgraph_infos[0].thunder_compiled_fns[0])[0])

g = torch.ones_like(out)
out.backward(g)

x_ref = x.detach().requires_grad_()
out_ref = model(x_ref)
out_ref.backward(g)
torch.testing.assert_close(x.grad, x_ref.grad)


def test_inconsistent_output_length_grad_transform():
from thunder.extend import OperatorExecutor
from thunder.core.proxies import AnyProxy, TensorProxy
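The relocated test takes no executor/device/dtype parameters, so it can presumably be run on its own. Below is a minimal sketch, assuming pytest is installed and the command is issued from the repository root; the run_moved_test.py wrapper name is hypothetical, and a plain command-line invocation works just as well.

# run_moved_test.py -- hypothetical helper; equivalent to running
#   pytest thunder/tests/test_dynamo.py -k test_torch_checkpoint_dynamo
# from the shell at the repository root.
import pytest

if __name__ == "__main__":
    # Select only the moved test by name and propagate pytest's exit code.
    raise SystemExit(pytest.main(["thunder/tests/test_dynamo.py", "-k", "test_torch_checkpoint_dynamo"]))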
