diff --git a/thunder/tests/test_inplace_functionalization.py b/thunder/tests/test_inplace_functionalization.py index 9d02f52740..ea60155a7c 100644 --- a/thunder/tests/test_inplace_functionalization.py +++ b/thunder/tests/test_inplace_functionalization.py @@ -580,9 +580,9 @@ def f(x, y, idx, src): jitted = executor.make_callable(f) o = jitted(a, b, idx, src) - assert a.allclose(a_) - assert b.allclose(b_) - assert o.allclose(o_) + torch.testing.assert_close(a, a_) + torch.testing.assert_close(b, b_) + torch.testing.assert_close(o, o_) @instantiate(