bwd run with tensor creation inside of trace
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
crcrpar committed Nov 7, 2024
1 parent a285a91 commit 526ca4b
Showing 3 changed files with 39 additions and 38 deletions.
20 changes: 11 additions & 9 deletions thunder/core/symbol.py
@@ -605,15 +605,17 @@ def import_ctx(self):
         # BoundSymbols of Symbols without Python implementations (either because they
         # have Python implementations or defined call ctxs) are assumed to need
         # a module import to run properly
-        assert self.sym.module is not None  # TODO: Is this a valid assumption?
-        module_name = self.sym.module.__name__
-        import_ctx = {module_name: self.sym.module}
-
-        # TODO Include the other modules on the path?
-        # Also includes the root module of this (potential) submodule
-        if "." in module_name:
-            root_name = module_name.split(".")[0]
-            import_ctx[root_name] = sys.modules[root_name]
+        if self.sym.module is not None:  # TODO: Is this a valid assumption?
+            module_name = self.sym.module.__name__
+            import_ctx = {module_name: self.sym.module}
+
+            # TODO Include the other modules on the path?
+            # Also includes the root module of this (potential) submodule
+            if "." in module_name:
+                root_name = module_name.split(".")[0]
+                import_ctx[root_name] = sys.modules[root_name]
+        else:
+            import_ctx = {}

         self._import_ctx.update(import_ctx)
         return self._import_ctx
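
Note: a minimal standalone sketch of the behavior this hunk introduces (the helper name build_import_ctx is hypothetical, not Thunder's API): a symbol with no backing module now contributes an empty import context instead of failing an assertion, while a dotted module name still pulls in its root package.

import sys
from types import ModuleType


def build_import_ctx(module: ModuleType | None) -> dict[str, ModuleType]:
    # Symbols without a backing module no longer trip an assert; they simply
    # contribute nothing to the import context.
    if module is None:
        return {}
    import_ctx = {module.__name__: module}
    # A dotted submodule also maps its root package so generated code can import it.
    if "." in module.__name__:
        root_name = module.__name__.split(".")[0]
        import_ctx[root_name] = sys.modules[root_name]
    return import_ctx


# Example: build_import_ctx(sys.modules["os.path"]) -> {"os.path": ..., "os": ...}
#          build_import_ctx(None)                   -> {}
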
1 change: 0 additions & 1 deletion thunder/core/transforms.py
@@ -1450,7 +1450,6 @@ def _get_gradfn_and_executor(
         ex_grad_transform: None | Callable = ex.get_grad_transform(bsym.sym)
         if ex_grad_transform is not None:
             return ex_grad_transform, ex
-            break

     # If the executor doesn't define its own grad transform, this just returns the default grad transform for the bsym
     gradfn = _grad_fn_map.get(bsym.sym.id, None)
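
The deleted line was dead code: the return on the line above already leaves the loop, so the break could never execute. A minimal sketch of the lookup pattern, with a hypothetical wrapper name (first_grad_transform) and duck-typed executors:

from collections.abc import Callable, Sequence
from typing import Any


def first_grad_transform(executors: Sequence[Any], bsym: Any) -> tuple[Callable | None, Any]:
    # Return the first executor-provided grad transform; returning from inside
    # the loop already stops iteration, which is why a trailing `break` is unreachable.
    for ex in executors:
        grad_transform = ex.get_grad_transform(bsym.sym)
        if grad_transform is not None:
            return grad_transform, ex
    return None, None
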
56 changes: 28 additions & 28 deletions thunder/tests/test_tensor_subclass.py
@@ -186,7 +186,7 @@ def g(x: torch.Tensor) -> ScaleTensorSubclass:

 @instantiate(
     dtypes=(thunder.core.dtypes.float32,),
-    decorators=(pytest.mark.parametrize("requires_grad", (False, True), ids=("no_bwd", "bwd")),),
+    decorators=(pytest.mark.parametrize("requires_grad", (False, True), ids=("fwd_only", "with_bwd")),),
 )
 def test_func_of_subclass_simple_math(executor, device, _, requires_grad):

@@ -198,37 +198,37 @@ def f(x: ScaleTensorSubclass, y: ScaleTensorSubclass) -> torch.Tensor:

    dtype = torch.float32
    shape = (2, 2)
    x = ScaleTensorSubclass(
        make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad),
        make_tensor((), device=device, dtype=dtype),
    )
    y = ScaleTensorSubclass(
        make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad),
        make_tensor((), device=device, dtype=dtype),
    )

    expected = f(x, y)
    actual = jitted(x, y)
    assert type(expected) is type(actual)
    torch.testing.assert_close(expected, actual)
    if requires_grad:
        actual.mean().backward()

    # def g(x: ScaleTensorSubclass, data: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    #     y = EncapsulateXandScale.apply(data, scale)
    #     out = x + y
    #     return out
    #
    # jitted = executor.make_callable(g)
    #
    # x = ScaleTensorSubclass(
    #     make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad),
    #     make_tensor((), device=device, dtype=dtype),
    # )
    # data = make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad)
    # scale = make_tensor((), device=device, dtype=dtype)
    # y = ScaleTensorSubclass(
    #     make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad),
    #     make_tensor((), device=device, dtype=dtype),
    # )
    #
    # expected = g(x, data, scale)
    # actual = jitted(x, data, scale)
    # expected = f(x, y)
    # actual = jitted(x, y)
    # assert type(expected) is type(actual)
    # torch.testing.assert_close(expected, actual)
    # if requires_grad:
    #     actual.mean().backward()

    def g(x: ScaleTensorSubclass, data: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        y = EncapsulateXandScale.apply(data, scale)
        out = x + y
        return out

    jitted = executor.make_callable(g)

    x = ScaleTensorSubclass(
        make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad),
        make_tensor((), device=device, dtype=dtype),
    )
    data = make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad)
    scale = make_tensor((), device=device, dtype=dtype)

    expected = g(x, data, scale)
    actual = jitted(x, data, scale)
    assert type(expected) is type(actual)
    torch.testing.assert_close(expected, actual)
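
The test relies on EncapsulateXandScale, which is defined earlier in this test file and not shown in the diff. As a rough orientation only (an assumed shape, not the repository's exact definition), such an autograd.Function builds the subclass from plain tensors inside the traced function, which is the "tensor creation inside of trace" the commit title refers to, and routes the incoming gradient back to the data tensor:

import torch

# Assumed sketch; ScaleTensorSubclass is the subclass defined in this test file.
class EncapsulateXandScale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, data: torch.Tensor, scale: torch.Tensor):
        # Construct the tensor subclass inside the autograd graph / trace.
        return ScaleTensorSubclass(data, scale)

    @staticmethod
    def backward(ctx, grad):
        # Gradient flows back to `data`; `scale` receives no gradient.
        return grad, None
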
