diff --git a/thunder/core/symbol.py b/thunder/core/symbol.py
index d8827eedbd..72c9a75ef0 100644
--- a/thunder/core/symbol.py
+++ b/thunder/core/symbol.py
@@ -605,15 +605,17 @@ def import_ctx(self):
         # BoundSymbols of Symbols without Python implementations (either because they
         # have Python implementations or defined call ctxs) are assumed to need
         # a module import to run properly
-        assert self.sym.module is not None  # TODO: Is this a valid assumption?
-        module_name = self.sym.module.__name__
-        import_ctx = {module_name: self.sym.module}
-
-        # TODO Include the other modules on the path?
-        # Also includes the root module of this (potential) submodule
-        if "." in module_name:
-            root_name = module_name.split(".")[0]
-            import_ctx[root_name] = sys.modules[root_name]
+        if self.sym.module is not None:  # TODO: Is this a valid assumption?
+            module_name = self.sym.module.__name__
+            import_ctx = {module_name: self.sym.module}
+
+            # TODO Include the other modules on the path?
+            # Also includes the root module of this (potential) submodule
+            if "." in module_name:
+                root_name = module_name.split(".")[0]
+                import_ctx[root_name] = sys.modules[root_name]
+        else:
+            import_ctx = {}
 
         self._import_ctx.update(import_ctx)
         return self._import_ctx
diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py
index 05f4bbc237..3998eef630 100644
--- a/thunder/core/transforms.py
+++ b/thunder/core/transforms.py
@@ -1450,7 +1450,6 @@ def _get_gradfn_and_executor(
         ex_grad_transform: None | Callable = ex.get_grad_transform(bsym.sym)
         if ex_grad_transform is not None:
             return ex_grad_transform, ex
-        break
 
     # If the executor doesn't define its own grad transform, this just returns the default grad transform for the bsym
     gradfn = _grad_fn_map.get(bsym.sym.id, None)
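For context on the symbol.py hunk above: a `BoundSymbol` whose `Symbol` carries no `module` used to trip the removed assertion; after this change it simply contributes an empty import context. Below is a minimal standalone sketch of that fallback, assuming a hypothetical helper name (`build_import_ctx` does not exist in thunder):

```python
# Illustrative mirror of the new import_ctx branch; not part of the patch.
import sys


def build_import_ctx(module) -> dict:
    if module is not None:
        module_name = module.__name__
        import_ctx = {module_name: module}
        # A dotted submodule also pulls in its root package, e.g. "email" for "email.mime".
        if "." in module_name:
            root_name = module_name.split(".")[0]
            import_ctx[root_name] = sys.modules[root_name]
    else:
        # New fallback: a symbol without a module contributes no imports.
        import_ctx = {}
    return import_ctx


import email.mime

assert build_import_ctx(email.mime) == {"email.mime": email.mime, "email": sys.modules["email"]}
assert build_import_ctx(None) == {}
```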
diff --git a/thunder/tests/test_tensor_subclass.py b/thunder/tests/test_tensor_subclass.py
index 1b5fd04b21..118072deb9 100644
--- a/thunder/tests/test_tensor_subclass.py
+++ b/thunder/tests/test_tensor_subclass.py
@@ -186,7 +186,7 @@ def g(x: torch.Tensor) -> ScaleTensorSubclass:
 
 @instantiate(
     dtypes=(thunder.core.dtypes.float32,),
-    decorators=(pytest.mark.parametrize("requires_grad", (False, True), ids=("no_bwd", "bwd")),),
+    decorators=(pytest.mark.parametrize("requires_grad", (False, True), ids=("fwd_only", "with_bwd")),),
 )
 def test_func_of_subclass_simple_math(executor, device, _, requires_grad):
     def f(x: ScaleTensorSubclass, y: ScaleTensorSubclass) -> torch.Tensor:
@@ -198,37 +198,37 @@ def f(x: ScaleTensorSubclass, y: ScaleTensorSubclass) -> torch.Tensor:
 
     dtype = torch.float32
     shape = (2, 2)
-    x = ScaleTensorSubclass(
-        make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad),
-        make_tensor((), device=device, dtype=dtype),
-    )
-    y = ScaleTensorSubclass(
-        make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad),
-        make_tensor((), device=device, dtype=dtype),
-    )
-
-    expected = f(x, y)
-    actual = jitted(x, y)
-    assert type(expected) is type(actual)
-    torch.testing.assert_close(expected, actual)
-    if requires_grad:
-        actual.mean().backward()
-
-    # def g(x: ScaleTensorSubclass, data: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
-    #     y = EncapsulateXandScale.apply(data, scale)
-    #     out = x + y
-    #     return out
-    #
-    # jitted = executor.make_callable(g)
-    #
     # x = ScaleTensorSubclass(
     #     make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad),
     #     make_tensor((), device=device, dtype=dtype),
     # )
-    # data = make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad)
-    # scale = make_tensor((), device=device, dtype=dtype)
+    # y = ScaleTensorSubclass(
+    #     make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad),
+    #     make_tensor((), device=device, dtype=dtype),
+    # )
     #
-    # expected = g(x, data, scale)
-    # actual = jitted(x, data, scale)
+    # expected = f(x, y)
+    # actual = jitted(x, y)
     # assert type(expected) is type(actual)
     # torch.testing.assert_close(expected, actual)
+    # if requires_grad:
+    #     actual.mean().backward()
+
+    def g(x: ScaleTensorSubclass, data: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+        y = EncapsulateXandScale.apply(data, scale)
+        out = x + y
+        return out
+
+    jitted = executor.make_callable(g)
+
+    x = ScaleTensorSubclass(
+        make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad),
+        make_tensor((), device=device, dtype=dtype),
+    )
+    data = make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad)
+    scale = make_tensor((), device=device, dtype=dtype)
+
+    expected = g(x, data, scale)
+    actual = jitted(x, data, scale)
+    assert type(expected) is type(actual)
+    torch.testing.assert_close(expected, actual)
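The re-enabled `g` path above feeds a `(data, scale)` pair through `EncapsulateXandScale.apply` (presumably a `torch.autograd.Function`, given the `.apply` call) and adds the result to an existing `ScaleTensorSubclass`. For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of a `__torch_dispatch__` wrapper subclass in the same spirit; the `ScaledTensor` name is illustrative only and is not the class defined in test_tensor_subclass.py:

```python
# Hypothetical stand-in for a (data, scale) wrapper subclass; thunder's real
# ScaleTensorSubclass and EncapsulateXandScale live in thunder/tests/test_tensor_subclass.py.
import torch
from torch.utils._pytree import tree_map


class ScaledTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data: torch.Tensor, scale: torch.Tensor):
        # Wrapper subclass: the instance holds no storage of its own, only metadata.
        return torch.Tensor._make_wrapper_subclass(cls, data.shape, dtype=data.dtype, device=data.device)

    def __init__(self, data: torch.Tensor, scale: torch.Tensor):
        self._data = data
        self._scale = scale

    def __repr__(self):
        return f"ScaledTensor(data={self._data}, scale={self._scale})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Unwrap ScaledTensor arguments into plain tensors, then run the aten op.
        def unwrap(t):
            return t._data * t._scale if isinstance(t, ScaledTensor) else t

        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))


x = ScaledTensor(torch.ones(2, 2), torch.tensor(2.0))
y = ScaledTensor(torch.ones(2, 2), torch.tensor(3.0))
print(x + y)  # plain tensor filled with 5.0
```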