Skip to content

Commit

Permalink
use pytest.mark.filterwarnings
Browse files Browse the repository at this point in the history
  • Loading branch information
kiya00 committed Oct 18, 2024
1 parent c544dc5 commit 809865f
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 5 deletions.
6 changes: 2 additions & 4 deletions thunder/tests/test_dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,7 @@ def f(x, y):


@requiresCUDA
@pytest.mark.filterwarnings(r"ignore:`torch\.cpu\.amp\.autocast\((.*?)\)` is deprecated.*:FutureWarning")
def test_checkpoint_converter():
import torch.utils.checkpoint as checkpoint

Expand Down Expand Up @@ -552,10 +553,7 @@ def forward(self, x):
out.backward(g)

ref_g = torch.ones_like(ref_out)
with warnings.catch_warnings():
# FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
warnings.filterwarnings("ignore", category=FutureWarning)
ref_out.backward(ref_g)
ref_out.backward(ref_g)
torch.testing.assert_close(x.grad, x_ref.grad)
torch.testing.assert_close(tuple(model.parameters()), tuple(ref_model.parameters()))

Expand Down
2 changes: 1 addition & 1 deletion thunder/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5263,7 +5263,7 @@ def _backward_checkpoint(
from thunder.core.transforms import vjp

result, grads = vjp(function)(args, grad_outputs, **kwargs)
return grads # result
return grads


#
Expand Down

0 comments on commit 809865f

Please sign in to comment.