Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add glu op #1477

Merged
merged 12 commits into from
Dec 3, 2024
1 change: 0 additions & 1 deletion thunder/executors/torchex.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import thunder.core.utils as utils

import thunder.torch as ltorch
from thunder.torch import DeviceLike, dtypeLike, TensorLike

from thunder.extend import OperatorExecutor, register_executor, add_always_executor

Expand Down
23 changes: 13 additions & 10 deletions thunder/tests/opinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -5488,6 +5488,7 @@ def var_sample_generator(op, device, dtype, requires_grad):
yield SampleInput(make_tensor((), device=device, dtype=dtype, requires_grad=requires_grad))


# glu requires that the size of the input along dimension `dim` be even
def glu_sample_generator(op, device, dtype, requires_grad, **kwargs):
beverlylytle marked this conversation as resolved.
Show resolved Hide resolved
make = partial(
make_tensor,
Expand All @@ -5497,20 +5498,20 @@ def glu_sample_generator(op, device, dtype, requires_grad, **kwargs):
)

cases = (
((4,),),
((4,), None),
((3, 4), 1),
beverlylytle marked this conversation as resolved.
Show resolved Hide resolved
((3, 4),),
((3, 4), None),
((3, 4), -1),
((4, 3), 0),
((4, 5, 8),),
((4, 5, 8), None),
((4, 5, 8), 0),
)

for case in cases:
if len(case) == 1:
yield SampleInput(make(*case))
for shape, dim in cases:
if dim is None:
yield SampleInput(make(*shape))
else:
shape, dim = case
yield SampleInput(make(shape), dim)
yield SampleInput(make(*shape), dim)


def glu_error_generator(op, device, **kwargs):
Expand All @@ -5520,16 +5521,18 @@ def glu_error_generator(op, device, **kwargs):
dtype=torch.float,
)
err_msg = r"Halving dimension must be even, but dimension .* is size .*"
beverlylytle marked this conversation as resolved.
Show resolved Hide resolved
# The size of the input along the default (last) dim is odd, which is unsupported.
yield (SampleInput(make((3,))), RuntimeError, err_msg)
yield (SampleInput(make((2, 2, 3))), RuntimeError, err_msg)
# The size of the input along dim=1 is odd, which is unsupported.
yield (SampleInput(make((4, 5, 8)), dim=1), RuntimeError, err_msg)


glu_opinfo = OpInfo(
ltorch.glu,
sample_input_generator=glu_sample_generator,
error_input_generator=glu_error_generator,
dtypes=(datatypes.floating, datatypes.complexfloating),
dtypes=(datatypes.inexact,),
torch_reference=torch.nn.functional.glu,
test_directives=(),
)
Expand All @@ -5540,7 +5543,7 @@ def glu_error_generator(op, device, **kwargs):
ltorch.mean,
sample_input_generator=reduction_sample_generator,
torch_reference=torch.mean,
dtypes=(datatypes.floating, datatypes.complexfloating),
dtypes=(datatypes.inexact,),
test_directives=(
# PyTorch doesn't support CPU and CUDA complex half mean
DecorateInfo(
Expand Down
10 changes: 3 additions & 7 deletions thunder/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2688,12 +2688,8 @@ def clone(a: TensorProxy, *, memory_format=torch.preserve_format) -> TensorProxy


@torchsymbol(torch.nn.functional.glu, is_method=False)
def glu(a: TensorProxy, /, dim: int = -1):
utils.check(
-a.ndim <= dim < a.ndim,
lambda: f"Dimension out of range (expected to be in range [{-a.ndim}, {a.ndim - 1}], but got {dim})",
IndexError,
)
def glu(a: TensorProxy, /, dim: int = -1) -> TensorProxy:
dim = utils.canonicalize_dim(len(a.shape), dim)
utils.check(
a.shape[dim] % 2 == 0,
lambda: f"Halving dimension must be even, but dimension {dim} is size {a.shape[dim]}",
Expand All @@ -2705,7 +2701,7 @@ def glu(a: TensorProxy, /, dim: int = -1):


@torchsymbol(torch.mean, is_method=True)
def mean(a: TensorProxy, /, dim=None, keepdim: bool = False, *, dtype=None):
def mean(a: TensorProxy, /, dim=None, keepdim: bool = False, *, dtype=None) -> TensorProxy:
beverlylytle marked this conversation as resolved.
Show resolved Hide resolved
dtype = dtype if dtype is not None else a.dtype
utils.check(
not utils.is_integer_dtype(dtype) and not utils.is_boolean_dtype(dtype),
Expand Down
Loading