diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py index 83b8718586..dabd0b204b 100644 --- a/thunder/tests/opinfos.py +++ b/thunder/tests/opinfos.py @@ -2484,10 +2484,26 @@ def where_sample_generator(op, device, dtype, requires_grad, **kwargs): yield SampleInput(pred, a, b) +def where_error_generator(op, device, dtype=torch.float32, **kwargs): + make = partial(make_tensor, device=device, dtype=dtype) + err_msg = r"torch.where\(\) does not support only specifying a condition" + yield ( + SampleInput( + make( + 5, + ) + ), + NotImplementedError, + err_msg, + ) + yield (SampleInput(make(2, 1, 2)), NotImplementedError, err_msg) + + where_opinfo = OpInfo( - clang.where, + ltorch.where, supports_grad=True, sample_input_generator=where_sample_generator, + error_input_generator=where_error_generator, torch_reference=torch.where, ) conditional_and_mask_ops.append(where_opinfo) diff --git a/thunder/tests/test_jit_general.py b/thunder/tests/test_jit_general.py index 330d3df115..a1dfa41898 100644 --- a/thunder/tests/test_jit_general.py +++ b/thunder/tests/test_jit_general.py @@ -621,7 +621,10 @@ def test_nanogpt(): "falcon-7b-like", "falcon-40b-like", "codellama2-like", - pytest.param("mixtral-like", marks=pytest.mark.xfail(raises=TypeError, reason="topk", strict=True)), + pytest.param( + "mixtral-like", + marks=pytest.mark.xfail(raises=(NotImplementedError, TypeError), reason="topk and where", strict=True), + ), ), ) @pytest.mark.parametrize( @@ -670,7 +673,10 @@ def test_litgpt_variants(name, device): "falcon-7b-like", "falcon-40b-like", "codellama2-like", - pytest.param("mixtral-like", marks=pytest.mark.xfail(raises=TypeError, reason="topk", strict=True)), + pytest.param( + "mixtral-like", + marks=pytest.mark.xfail(raises=(NotImplementedError, TypeError), reason="topk and where", strict=True), + ), ), ) @pytest.mark.parametrize( diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index 780125a688..ac8871dbbb 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -1565,7 +1565,14 @@ def tril(a: TensorLike, /, diagonal: int = 0, *, fill_value: None | Number = Non @torchsymbol(torch.where, is_method=True) -def where(pred: TensorLike, a: Number | TensorLike, b: Number | TensorLike, /) -> TensorLike: +def where( + pred: TensorLike, a: None | Number | TensorLike = None, b: None | Number | TensorLike = None, / +) -> TensorLike: + utils.check( + isinstance(a, (Number, TensorProxy)) and isinstance(b, (Number, TensorProxy)), + lambda: f"torch.where() does not support only specifying a condition", + exception_type=NotImplementedError, + ) return clang.where(pred, a, b)