Skip to content

Commit

Permalink
failing test case as starter
Browse files Browse the repository at this point in the history
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
  • Loading branch information
crcrpar committed Nov 3, 2024
1 parent 92e79c4 commit fa05c82
Showing 1 changed file with 43 additions and 1 deletion.
44 changes: 43 additions & 1 deletion thunder/tests/test_tensor_subclass.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
from __future__ import annotations
from typing import TYPE_CHECKING

import pytest
import torch
from torch.utils import _pytree as pytree

import thunder
from thunder.tests.framework import instantiate, NOTHING
from thunder.core.proxies import SubclassTensorProxy
from thunder.tests.framework import instantiate, nvFuserExecutor
from thunder.tests.make_tensor import make_tensor

if TYPE_CHECKING:
from typing import Any
from thunder.core.symbol import BoundSymbol


class ScaleTensorSubclass(torch.Tensor):
Expand Down Expand Up @@ -118,3 +122,41 @@ def f(x: torch.Tensor, scale: torch.Tensor):
expected = f(x, scale)
actual = jitted(x, scale)
torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale))


@instantiate(
dtypes=(thunder.core.dtypes.float32,),
)
def test_func_of_subclass_simple_math(executor, device, _):

def f(x: ScaleTensorSubclass, data: torch.Tensor, scale: torch.Tensor) -> ScaleTensorSubclass:

y = ScaleTensorSubclass(data, scale)
out = x + y
return out

jitted = executor.make_callable(f)

dtype = torch.float32
shape = (2, 2)
x = ScaleTensorSubclass(
make_tensor(shape, device=device, dtype=dtype),
make_tensor((), device=device, dtype=dtype),
)
data = make_tensor(shape, device=device, dtype=dtype)
scale = make_tensor((), device=device, dtype=dtype)

expected = f(x, data, scale)
actual = jitted(x, data, scale)
if executor == nvFuserExecutor:
with pytest.raises(Exception):
assert type(expected) is type(actual)
torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale))
else:
assert type(expected) is type(actual)
torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale))

return_bsym: BoundSymbol = thunder.last_traces(jitted)[-1].bound_symbols[-1]
return_proxy = return_bsym.flat_args[0]
# FIXME(crcrpar): Implement a trace transform that corrects the output type of bsyms involving tensor subclasses
assert not isinstance(return_proxy, SubclassTensorProxy)

0 comments on commit fa05c82

Please sign in to comment.