Skip to content

Commit

Permalink
update torch.where to throw error for single input (#192)
Browse files Browse the repository at this point in the history
  • Loading branch information
k223kim authored Apr 17, 2024
1 parent eaa0dd2 commit 7a887ad
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 4 deletions.
18 changes: 17 additions & 1 deletion thunder/tests/opinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -2484,10 +2484,26 @@ def where_sample_generator(op, device, dtype, requires_grad, **kwargs):
yield SampleInput(pred, a, b)


def where_error_generator(op, device, dtype=torch.float32, **kwargs):
make = partial(make_tensor, device=device, dtype=dtype)
err_msg = r"torch.where\(\) does not support only specifying a condition"
yield (
SampleInput(
make(
5,
)
),
NotImplementedError,
err_msg,
)
yield (SampleInput(make(2, 1, 2)), NotImplementedError, err_msg)


where_opinfo = OpInfo(
clang.where,
ltorch.where,
supports_grad=True,
sample_input_generator=where_sample_generator,
error_input_generator=where_error_generator,
torch_reference=torch.where,
)
conditional_and_mask_ops.append(where_opinfo)
Expand Down
10 changes: 8 additions & 2 deletions thunder/tests/test_jit_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,10 @@ def test_nanogpt():
"falcon-7b-like",
"falcon-40b-like",
"codellama2-like",
pytest.param("mixtral-like", marks=pytest.mark.xfail(raises=TypeError, reason="topk", strict=True)),
pytest.param(
"mixtral-like",
marks=pytest.mark.xfail(raises=(NotImplementedError, TypeError), reason="topk and where", strict=True),
),
),
)
@pytest.mark.parametrize(
Expand Down Expand Up @@ -670,7 +673,10 @@ def test_litgpt_variants(name, device):
"falcon-7b-like",
"falcon-40b-like",
"codellama2-like",
pytest.param("mixtral-like", marks=pytest.mark.xfail(raises=TypeError, reason="topk", strict=True)),
pytest.param(
"mixtral-like",
marks=pytest.mark.xfail(raises=(NotImplementedError, TypeError), reason="topk and where", strict=True),
),
),
)
@pytest.mark.parametrize(
Expand Down
9 changes: 8 additions & 1 deletion thunder/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1565,7 +1565,14 @@ def tril(a: TensorLike, /, diagonal: int = 0, *, fill_value: None | Number = Non


@torchsymbol(torch.where, is_method=True)
def where(pred: TensorLike, a: Number | TensorLike, b: Number | TensorLike, /) -> TensorLike:
def where(
pred: TensorLike, a: None | Number | TensorLike = None, b: None | Number | TensorLike = None, /
) -> TensorLike:
utils.check(
isinstance(a, (Number, TensorProxy)) and isinstance(b, (Number, TensorProxy)),
lambda: f"torch.where() does not support only specifying a condition",
exception_type=NotImplementedError,
)
return clang.where(pred, a, b)


Expand Down

0 comments on commit 7a887ad

Please sign in to comment.