Skip to content

Commit

Permalink
add enable_fallback_to_torch compile option, move registration into g…
Browse files Browse the repository at this point in the history
…eneral_jit_lookaside; fix test
  • Loading branch information
kiya00 committed Jul 18, 2024
1 parent 7e62ff1 commit 210ed60
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 9 deletions.
8 changes: 0 additions & 8 deletions thunder/core/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6155,14 +6155,6 @@ def _impl(fn, *args, **kwargs):
return _interpret_call(unbound_fn, slf, *args, **kwargs)

# (2) Handles lookasides
# Check if the torch operator is not registered in the _torch_to_thunder_function_map, fallback to automatic registration.
from thunder.torch import get_torch_fallback_operators_module

m = get_torch_fallback_operators_module(fn)
if m is not None:
from thunder.torch import meta_adaptor, register_torch_op

register_torch_op(fn, meta_adaptor(fn), m)
lookaside_fn: INTERPRETER_SIGNALS | None | Callable = compilectx.lookaside(fn, *args, **kwargs)
if lookaside_fn is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
# Happens with sharp edges, for example
Expand Down
11 changes: 11 additions & 0 deletions thunder/core/jit_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -1051,6 +1051,17 @@ def proxy_recursion(v):

# TODO Document this function (with steps)
def general_jit_lookaside(fn, *args, **kwargs) -> None | Callable:
cd = get_compile_data()
enable_fallback_to_torch = cd.compile_options.get("enable_fallback_to_torch", True)
if enable_fallback_to_torch:
from thunder.torch import get_torch_fallback_operators_module

# Check if the torch operator is not registered in the _torch_to_thunder_function_map, fallback to automatic registration.
if (m := get_torch_fallback_operators_module(fn)) is not None:
from thunder.torch import meta_adaptor, register_torch_op

register_torch_op(fn, meta_adaptor(fn), m)

# Identifies the lookaside
lookaside: None | Callable

Expand Down
2 changes: 1 addition & 1 deletion thunder/tests/test_jit_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def should_error(x):
jno_error = thunder.jit(no_error)
jno_error(x)

jshould_error = thunder.jit(should_error)
jshould_error = thunder.jit(should_error, enable_fallback_to_torch=False)
with pytest.raises(NotImplementedError):
jshould_error(x)

Expand Down

0 comments on commit 210ed60

Please sign in to comment.