Skip to content

Commit b27db92

Browse files
committed
some test cases
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
1 parent cd5e5df commit b27db92

File tree

1 file changed

+53
-1
lines changed

1 file changed

+53
-1
lines changed

thunder/tests/test_inplace_functionalization.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from thunder.core import dtypes
1010
from thunder.core.prims import PrimIDs
11-
from thunder.tests.framework import ops, requiresCUDA
11+
from thunder.tests.framework import instantiate, ops, requiresCUDA, NOTHING
1212
from thunder.tests.opinfos import opinfos, OpInfo, make_number, SampleInput
1313
from thunder.tests.make_tensor import make_tensor
1414
from thunder.torch import _torch_to_thunder_function_map, _inplace_to_out_of_place
@@ -178,3 +178,55 @@ def test_parse_resnet18(train: bool):
178178
jitted = thunder.jit(model)
179179
x = make_tensor((1, 3, 224, 224), dtype=dtype, device=device)
180180
torch.testing.assert_close(jitted(x), ref_model(x))
181+
182+
183+
@instantiate(
184+
dtypes=NOTHING,
185+
)
186+
def test_inplace_to_views(executor, device, _):
187+
import thunder
188+
189+
def f(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
190+
c = torch.exp(a)
191+
d = torch.tanh(b)
192+
193+
e = c.view(-1)
194+
e += d.flatten()
195+
196+
d.div_(a)
197+
return c, d, e
198+
199+
a, b = (make_tensor((2, 2), device=device, dtype=torch.float32) for _ in range(2))
200+
a_, b_ = a.clone().detach(), b.clone().detach()
201+
202+
jittd_f = thunder.jit(f, executors=executor.executors_list())
203+
204+
c, d, e = jittd_f(a, b)
205+
c_, d_, e_ = f(a_, b_)
206+
207+
torch.testing.assert_close((c, d, e), (c_, d_, e_))
208+
209+
210+
@instantiate(
211+
dtypes=NOTHING,
212+
)
213+
def test_error_of_inplace_to_views(executor, device, _):
214+
import thunder
215+
216+
def f(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
217+
c = torch.exp(a)
218+
d = torch.tanh(b)
219+
220+
e = c.flatten()
221+
e += d.flatten()
222+
223+
d.div_(a)
224+
return c, d, e
225+
226+
a, b = (make_tensor((2, 2), device=device, dtype=torch.float32) for _ in range(2))
227+
a_, b_ = a.clone().detach(), b.clone().detach()
228+
229+
jittd_f = thunder.jit(f, executors=executor.executors_list())
230+
231+
with pytest.raises(NotImplementedError, match="in-place op of `torch.Tensor.add_` to `torch.flatten` output"):
232+
c, d, e = jittd_f(a, b)

0 commit comments

Comments
 (0)