diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index c3c7c798af..8f87bb82ae 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -905,7 +905,17 @@ def general_jit_lookaside(fn, *args, **kwargs) -> None | Callable: def is_from_torch(fn): return hasattr(fn, "__module__") and fn.__module__ and fn.__module__.startswith("torch") - if is_opaque(fn) and is_from_torch(fn): + has_tensor_arg = False + for a in args: + if isinstance(a, TensorProxy): + has_tensor_arg = True + break + if isinstance(a, Sequence): + if any(isinstance(i, TensorProxy) for i in a): + has_tensor_arg = True + break + + if is_opaque(fn) and is_from_torch(fn) and has_tensor_arg: if fn.__module__.startswith("torch._C"): return lookaside diff --git a/thunder/tests/test_jit_general.py b/thunder/tests/test_jit_general.py index d1d55073d7..330d3df115 100644 --- a/thunder/tests/test_jit_general.py +++ b/thunder/tests/test_jit_general.py @@ -467,6 +467,14 @@ def foo(a, b, i): assert_close(expected, actual) +# see https://github.com/Lightning-AI/lightning-thunder/issues/95 +def test_get_default_dtype(): + def foo(): + return torch.get_default_dtype() + + assert foo() == thunder.jit(foo)() + + @pytest.mark.parametrize( "device", ("cpu", "cuda"),