|
8 | 8 |
|
9 | 9 | from thunder.core import dtypes
|
10 | 10 | from thunder.core.prims import PrimIDs
|
11 |
| -from thunder.tests.framework import ops, requiresCUDA |
| 11 | +from thunder.tests.framework import instantiate, ops, requiresCUDA, NOTHING |
12 | 12 | from thunder.tests.opinfos import opinfos, OpInfo, make_number, SampleInput
|
13 | 13 | from thunder.tests.make_tensor import make_tensor
|
14 | 14 | from thunder.torch import _torch_to_thunder_function_map, _inplace_to_out_of_place
|
@@ -178,3 +178,55 @@ def test_parse_resnet18(train: bool):
|
178 | 178 | jitted = thunder.jit(model)
|
179 | 179 | x = make_tensor((1, 3, 224, 224), dtype=dtype, device=device)
|
180 | 180 | torch.testing.assert_close(jitted(x), ref_model(x))
|
| 181 | + |
| 182 | + |
@instantiate(
    dtypes=NOTHING,
)
def test_inplace_to_views(executor, device, _):
    """In-place ops on views must be functionalized correctly.

    ``e = c.view(-1)`` aliases ``c``, so ``e += ...`` must also update ``c``;
    the jitted function is checked against eager execution on cloned inputs.
    """
    import thunder

    def f(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        c = torch.exp(a)
        d = torch.tanh(b)

        # `e` is a view of `c`: the in-place add below must be visible in `c` too.
        e = c.view(-1)
        e += d.flatten()

        d.div_(a)
        return c, d, e

    a, b = (make_tensor((2, 2), device=device, dtype=torch.float32) for _ in range(2))
    # Detached clones so the jitted call's in-place effects cannot leak into
    # the eager reference run.
    a_, b_ = a.clone().detach(), b.clone().detach()

    jitted_f = thunder.jit(f, executors=executor.executors_list())

    c, d, e = jitted_f(a, b)
    c_, d_, e_ = f(a_, b_)

    torch.testing.assert_close((c, d, e), (c_, d_, e_))
| 208 | + |
| 209 | + |
@instantiate(
    dtypes=NOTHING,
)
def test_error_of_inplace_to_views(executor, device, _):
    """Unsupported in-place-on-view patterns must raise a clear error.

    Unlike ``Tensor.view``, ``c.flatten()`` may produce a copy rather than a
    view, so an in-place add into its output is not supported by the jit;
    a ``NotImplementedError`` naming the offending ops is expected.
    """
    import thunder

    def f(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        c = torch.exp(a)
        d = torch.tanh(b)

        # `flatten` output is not guaranteed to alias `c` — this is the
        # unsupported case the jit should reject.
        e = c.flatten()
        e += d.flatten()

        d.div_(a)
        return c, d, e

    a, b = (make_tensor((2, 2), device=device, dtype=torch.float32) for _ in range(2))
    a_, b_ = a.clone().detach(), b.clone().detach()

    jitted_f = thunder.jit(f, executors=executor.executors_list())

    with pytest.raises(NotImplementedError, match="in-place op of `torch.Tensor.add_` to `torch.flatten` output"):
        c, d, e = jitted_f(a, b)
0 commit comments