support MySubclass(...) called inside of torch.autograd.Function
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
crcrpar committed Nov 6, 2024
1 parent 0235a21 commit 24b1628
Showing 2 changed files with 65 additions and 2 deletions.
thunder/core/proxies.py (3 changes: 2 additions & 1 deletion)

@@ -1937,7 +1937,8 @@ def __init__(self, *args, **kwargs):
             self._non_tensors,
             output=self,
         )
-        get_tracectx().add_bound_symbol(bsym)
+        current_trace = get_tracectx()
+        current_trace.scopes[-1].append(bsym)
 
     def replace(self, **changes):
         r"""Return a copy of the SubclassTensorProxy object with new values for the specified fields as given to the constructor as arguments.
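Why the proxies.py change matters: when SubclassTensorProxy.__init__ runs while an inner scope is being traced (for example while the interpreter is inside a torch.autograd.Function, as the commit title suggests), add_bound_symbol presumably records the constructor at the top level of the trace, whereas scopes[-1].append(bsym) records it in the scope currently being built. Below is a minimal sketch of that scope-stack pattern; ToyTrace and nested_scope are illustrative stand-ins, not Thunder's real trace API.

from contextlib import contextmanager


class ToyTrace:
    """Toy stand-in for a trace context: a flat list of top-level bound symbols
    plus a stack of scopes, where scopes[-1] is wherever symbols are currently
    being collected."""

    def __init__(self):
        self.bound_symbols = []                 # top-level symbols of the trace
        self.scopes = [self.bound_symbols]      # innermost scope is scopes[-1]

    def add_bound_symbol(self, bsym):
        # Always records at the top level, even while a nested region is traced.
        self.bound_symbols.append(bsym)

    @contextmanager
    def nested_scope(self):
        # While tracing a nested region (e.g. the body of an autograd.Function),
        # push a fresh subsymbol list so appends land inside that region.
        subsymbols = []
        self.scopes.append(subsymbols)
        try:
            yield subsymbols
        finally:
            self.scopes.pop()


trace = ToyTrace()
trace.scopes[-1].append("top_level_op")          # same as add_bound_symbol here

with trace.nested_scope() as inner:
    trace.add_bound_symbol("ctor_via_add_bound_symbol")  # escapes to the top level
    trace.scopes[-1].append("ctor_via_scopes")           # stays inside `inner`

print(trace.bound_symbols)  # ['top_level_op', 'ctor_via_add_bound_symbol']
print(inner)                # ['ctor_via_scopes']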
thunder/tests/test_tensor_subclass.py (64 changes: 63 additions & 1 deletion)

@@ -4,13 +4,43 @@
 import torch
 
 import thunder
-from thunder.tests.framework import instantiate, NOTHING
+from thunder.tests.framework import instantiate
 from thunder.tests.make_tensor import make_tensor
 
 if TYPE_CHECKING:
     from typing import Any
 
 
+@torch._dynamo.allow_in_graph
+class EncapsulateXandScale(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x: torch.Tensor, scale: torch.Tensor):
+        return ScaleTensorSubclass(x, scale)
+
+    @staticmethod
+    def backward(ctx, grad):
+        return grad, None
+
+
+def encapsulate_x_and_scale(x, scale) -> ScaleTensorSubclass:
+    return EncapsulateXandScale.apply(x, scale)
+
+
+@torch._dynamo.allow_in_graph
+class ToScaleTensorSubclass(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x: torch.Tensor):
+        return ScaleTensorSubclass.from_tensor(x)
+
+    @staticmethod
+    def backward(ctx, grad):
+        return grad
+
+
+def to_scale_tensor_subclass(x: torch.Tensor) -> ScaleTensorSubclass:
+    return ToScaleTensorSubclass.apply(x)
+
+
 class ScaleTensorSubclass(torch.Tensor):
     _x: torch.Tensor
     _scale: torch.Tensor

@@ -118,3 +148,35 @@ def f(x: torch.Tensor, scale: torch.Tensor):
     expected = f(x, scale)
     actual = jitted(x, scale)
     torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale))
+
+
+@instantiate(
+    dtypes=(thunder.core.dtypes.float32,),
+)
+def test_func_calling_converter(executor, device, _):
+
+    def f(x: torch.Tensor, scale: torch.Tensor) -> ScaleTensorSubclass:
+        y = encapsulate_x_and_scale(x, scale)
+        return y
+
+    jitted = executor.make_callable(f)
+
+    dtype = torch.float32
+    shape = (2, 2)
+    x = make_tensor(shape, device=device, dtype=dtype)
+    scale = make_tensor((), device=device, dtype=dtype)
+
+    expected = f(x, scale)
+    actual = jitted(x, scale)
+    torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale))
+
+    def g(x: torch.Tensor) -> ScaleTensorSubclass:
+        y = to_scale_tensor_subclass(x)
+        return y
+
+    jitted = thunder.jit(g)
+    x = make_tensor(shape, device=device, dtype=dtype)
+
+    expected = g(x)
+    actual = jitted(x)
+    torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale))
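The body of ScaleTensorSubclass is collapsed in the diff above. For orientation only, here is a minimal hypothetical wrapper subclass with the surface the tests rely on (_x, _scale, from_tensor, construction via __new__); it is not the definition from thunder/tests/test_tensor_subclass.py, and it omits the __torch_dispatch__ handling a usable subclass would also need.

import torch


class ScaleTensorSketch(torch.Tensor):
    """Illustrative wrapper subclass holding a payload tensor and a scale."""

    _x: torch.Tensor
    _scale: torch.Tensor

    def __new__(cls, x: torch.Tensor, scale: torch.Tensor):
        # Create a wrapper that carries x's metadata but no storage of its own.
        self = torch.Tensor._make_wrapper_subclass(
            cls,
            x.size(),
            strides=x.stride(),
            dtype=x.dtype,
            device=x.device,
            requires_grad=x.requires_grad,
        )
        self._x = x
        self._scale = scale
        return self

    @classmethod
    def from_tensor(cls, x: torch.Tensor) -> "ScaleTensorSketch":
        # Hypothetical converter: derive the scale from the data, e.g. abs-max.
        return cls(x, x.abs().amax())

The converter tests above then only check that the jitted function produces the same ._x and ._scale payloads as the eager call.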