diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index 070bdbe85b..461750690f 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -914,11 +914,7 @@ def is_from_torch(fn): "For getting a full list of unsupported functions we recommend " \ "running `examine` from the `thunder.examine` module on your callable/PyTorch module." - sharp_edges: SHARP_EDGES_OPTIONS = get_minimal_ctx().sharp_edges - if sharp_edges is SHARP_EDGES_OPTIONS.ALLOW: - warnings.warn(get_calling_opaque_torch_msg()) - - return _general_jit_sharp_edge(get_calling_opaque_torch_msg(), None) + return do_raise(NotImplementedError(get_calling_opaque_torch_msg())) return lookaside diff --git a/thunder/tests/test_jit_general.py b/thunder/tests/test_jit_general.py index 3671bd12e7..459b0a182a 100644 --- a/thunder/tests/test_jit_general.py +++ b/thunder/tests/test_jit_general.py @@ -50,23 +50,23 @@ def skipif_not_pytorch_2_1(f): )(f) -def test_jitting_through_opaque_torch_symbols_sharp_edge(): - def no_sharp_edge(x): - # randn_like is in ltorch - return torch.randn_like(x) - - def sharp_edge(x): - # rand_like is not yet in ltroch - return torch.rand_like(x) - - x = torch.rand(1) - - jno_sharp_edge = thunder.jit(no_sharp_edge, sharp_edges="error") - jno_sharp_edge(x) - - jsharp_edge = thunder.jit(sharp_edge, sharp_edges="error") - with pytest.raises(JITSharpEdgeError): - jsharp_edge(x) +#def test_jitting_through_opaque_torch_symbols_sharp_edge(): +# def no_sharp_edge(x): +# # randn_like is in ltorch +# return torch.randn_like(x) +# +# def sharp_edge(x): +# # rand_like is not yet in ltroch +# return torch.rand_like(x) +# +# x = torch.rand(1) +# +# jno_sharp_edge = thunder.jit(no_sharp_edge, sharp_edges="error") +# jno_sharp_edge(x) +# +# jsharp_edge = thunder.jit(sharp_edge, sharp_edges="error") +# with pytest.raises(JITSharpEdgeError): +# jsharp_edge(x) def test_binary_add_tensors(): @@ -613,7 +613,7 @@ def test_nanogpt(): "falcon-7b-like", "falcon-40b-like", "codellama2-like", - pytest.param("mixtral-like", marks=pytest.mark.xfail(raises=TypeError, reason="topk", strict=True)), + pytest.param("mixtral-like", marks=pytest.mark.xfail(raises=NotImplementedError, reason="topk", strict=True)), ), ) @pytest.mark.parametrize( @@ -662,7 +662,7 @@ def test_litgpt_variants(name, device): "falcon-7b-like", "falcon-40b-like", "codellama2-like", - pytest.param("mixtral-like", marks=pytest.mark.xfail(raises=TypeError, reason="topk", strict=True)), + pytest.param("mixtral-like", marks=pytest.mark.xfail(raises=NotImplementedError, reason="topk", strict=True)), ), ) @pytest.mark.parametrize(