Relax erroring on calling torch functions without thunder equivalents (
t-vi authored Apr 3, 2024
1 parent 26d6255 commit 4d026d9
Showing 2 changed files with 19 additions and 1 deletion.
thunder/core/jit_ext.py (12 changes: 11 additions & 1 deletion)
@@ -905,7 +905,17 @@ def general_jit_lookaside(fn, *args, **kwargs) -> None | Callable:
         def is_from_torch(fn):
             return hasattr(fn, "__module__") and fn.__module__ and fn.__module__.startswith("torch")

-        if is_opaque(fn) and is_from_torch(fn):
+        has_tensor_arg = False
+        for a in args:
+            if isinstance(a.value, TensorProxy):
+                has_tensor_arg = True
+                break
+            if isinstance(a.value, Sequence):
+                if any(isinstance(i, TensorProxy) for i in a.value):
+                    has_tensor_arg = True
+                    break
+
+        if is_opaque(fn) and is_from_torch(fn) and has_tensor_arg:
             if fn.__module__.startswith("torch._C"):
                 return lookaside

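
For illustration, a minimal standalone sketch of the new guard follows. It is not thunder code: FakeTensorProxy is a hypothetical stand-in for thunder's TensorProxy, and the sketch checks the arguments directly rather than the wrapped a.value objects used above. It only shows when an opaque torch call is still treated as having a tensor argument (and so, per the commit title, still hits the stricter no-thunder-equivalent path) versus when it is now allowed to fall through.

from collections.abc import Sequence

class FakeTensorProxy:
    # Hypothetical stand-in for thunder's TensorProxy, used only so this
    # sketch runs without thunder installed.
    pass

def has_tensor_arg(args) -> bool:
    # Mirrors the check added above: an argument counts if it is a tensor
    # proxy itself, or a sequence that contains one.
    for a in args:
        if isinstance(a, FakeTensorProxy):
            return True
        if isinstance(a, Sequence) and any(isinstance(i, FakeTensorProxy) for i in a):
            return True
    return False

print(has_tensor_arg([]))                        # False: e.g. torch.get_default_dtype() falls through
print(has_tensor_arg([2.0, "cpu"]))              # False: non-tensor arguments also fall through
print(has_tensor_arg([FakeTensorProxy()]))       # True: tensor arguments keep the stricter path
print(has_tensor_arg([[FakeTensorProxy(), 1]]))  # True: tensor proxies inside sequences are detected
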
thunder/tests/test_jit_general.py (8 changes: 8 additions & 0 deletions)
@@ -467,6 +467,14 @@ def foo(a, b, i):
     assert_close(expected, actual)


+# see https://github.com/Lightning-AI/lightning-thunder/issues/95
+def test_get_default_dtype():
+    def foo():
+        return torch.get_default_dtype()
+
+    assert foo() == thunder.jit(foo)()
+
+
 @pytest.mark.parametrize(
     "device",
     ("cpu", "cuda"),
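
For reference, the added test reduces to the following standalone snippet (assuming thunder and torch are importable). Before this commit, jitting foo raised because torch.get_default_dtype is an opaque torch function with no thunder equivalent; since it takes no tensor arguments, it is now interpreted directly.

import torch
import thunder

def foo():
    return torch.get_default_dtype()

# The jitted function should return the same dtype as the eager call.
assert thunder.jit(foo)() == foo()
print(foo())  # e.g. torch.float32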
