Skip to content

Commit

Permalink
Relax erroring on calling torch functions without thunder equivalents
Browse files Browse the repository at this point in the history
Some functions are actually OK, so we leave those alone.
  • Loading branch information
t-vi committed Mar 28, 2024
1 parent 94c9494 commit 198becc
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
12 changes: 11 additions & 1 deletion thunder/core/jit_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,7 +905,17 @@ def general_jit_lookaside(fn, *args, **kwargs) -> None | Callable:
def is_from_torch(fn):
return hasattr(fn, "__module__") and fn.__module__ and fn.__module__.startswith("torch")

if is_opaque(fn) and is_from_torch(fn):
has_tensor_arg = False
for a in args:
if isinstance(a, TensorProxy):
has_tensor_arg = True
break
if isinstance(a, Sequence):
if any(isinstance(i, TensorProxy) for i in a):
has_tensor_arg = True
break

if is_opaque(fn) and is_from_torch(fn) and has_tensor_arg:
if fn.__module__.startswith("torch._C"):
return lookaside

Expand Down
8 changes: 8 additions & 0 deletions thunder/tests/test_jit_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,14 @@ def foo(a, b, i):
assert_close(expected, actual)


# see https://github.com/Lightning-AI/lightning-thunder/issues/95
def test_get_default_dtype():
def foo():
return torch.get_default_dtype()

assert foo() == thunder.jit(foo)()


@pytest.mark.parametrize(
"device",
("cpu", "cuda"),
Expand Down

0 comments on commit 198becc

Please sign in to comment.