From c25a908888f15968182e3a7864497739da594225 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 5 Nov 2024 13:58:25 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- thunder/tests/test_inplace_copy.py | 42 +++++++++++++++--------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/thunder/tests/test_inplace_copy.py b/thunder/tests/test_inplace_copy.py index b3cf518f4d..f89d2de4f2 100644 --- a/thunder/tests/test_inplace_copy.py +++ b/thunder/tests/test_inplace_copy.py @@ -155,27 +155,27 @@ def func3(x, y): traced_foo(a, b) -@instantiate(executors=(nvFuserExecutor,), dtypes=(thunder.float32,)) -def test_inplace_copy_dst_copy_returned_issue_1109(executor, device, dtype): - def func(T0): - T1 = torch.sin(T0) +@instantiate(executors=(nvFuserExecutor,), dtypes=(thunder.float32,)) +def test_inplace_copy_dst_copy_returned_issue_1109(executor, device, dtype): + def func(T0): + T1 = torch.sin(T0) prims.copy_(T1, T0) T2 = torch.cos(T1) prims.copy_(T2, T0) - # T1 & T2 should be returned as separate buffer, instead of sharing - # storage with T0 - return T1, T2 - - tdtype = ttorch.to_torch_dtype(dtype) - # This pattern is unsafe in general. Disabling sanity check to silence - # exception for testing - traced_foo = executor.make_callable(func, disable_inplace_copy_check=True) - a = make_tensor((4, 4), device=device, dtype=tdtype) - a_ref = a.clone() - - o_thunder = traced_foo(a) - o_eager = func(a_ref) - - assert_close(a_ref, a) - for o, o_ref in zip(o_thunder, o_eager): - assert_close(o, o_ref) + # T1 & T2 should be returned as separate buffer, instead of sharing + # storage with T0 + return T1, T2 + + tdtype = ttorch.to_torch_dtype(dtype) + # This pattern is unsafe in general. Disabling sanity check to silence + # exception for testing + traced_foo = executor.make_callable(func, disable_inplace_copy_check=True) + a = make_tensor((4, 4), device=device, dtype=tdtype) + a_ref = a.clone() + + o_thunder = traced_foo(a) + o_eager = func(a_ref) + + assert_close(a_ref, a) + for o, o_ref in zip(o_thunder, o_eager): + assert_close(o, o_ref)