Skip to content

Commit

Permalink
Error on torch functions which are not part of the ltorch
Browse files Browse the repository at this point in the history
  • Loading branch information
nikitaved committed Mar 26, 2024
1 parent bdf5c3f commit 693509d
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 12 deletions.
12 changes: 8 additions & 4 deletions thunder/core/jit_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,10 +908,14 @@ def is_from_torch(fn):
# Torch functions have __name__ defined
fn_name = f"{fn.__module__}.{fn.__name__}"

# For now, only torch-like opaque functions are sharp edges
return _general_jit_sharp_edge(
f"Trying to call function {fn_name}, but it's unsupported. Please file an issue requesting support.",
None,
# TODO: maybe convert to sharp edge later on
return do_raise(
NotImplementedError(
f"Trying to call function {fn_name}, but it is not yet supported. "
"Please file an issue requesting support. "
"For getting a full list of unsupported functions we recommend "
"running `examine` from the `thunder.examine` module on your callable/PyTorch module."
)
)

return lookaside
Expand Down
16 changes: 8 additions & 8 deletions thunder/tests/test_jit_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,23 +50,23 @@ def skipif_not_pytorch_2_1(f):
)(f)


def test_jitting_through_opaque_torch_symbols_sharp_edge():
def no_sharp_edge(x):
def test_jitting_through_opaque_torch_symbols_errors():
def no_error(x):
# randn_like is in ltorch
return torch.randn_like(x)

def sharp_edge(x):
def should_error(x):
# rand_like is not yet in ltroch
return torch.rand_like(x)

x = torch.rand(1)

jno_sharp_edge = thunder.jit(no_sharp_edge, sharp_edges="error")
jno_sharp_edge(x)
jno_error = thunder.jit(no_error)
jno_error(x)

jsharp_edge = thunder.jit(sharp_edge, sharp_edges="error")
with pytest.raises(JITSharpEdgeError):
jsharp_edge(x)
jshould_error = thunder.jit(should_error)
with pytest.raises(NotImplementedError):
jshould_error(x)


def test_binary_add_tensors():
Expand Down

0 comments on commit 693509d

Please sign in to comment.