Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add glu op #1477

Merged
merged 12 commits into from
Dec 3, 2024
1 change: 0 additions & 1 deletion thunder/executors/torchex.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import thunder.core.utils as utils

import thunder.torch as ltorch
from thunder.torch import DeviceLike, dtypeLike, TensorLike

from thunder.extend import OperatorExecutor, register_executor, add_always_executor

Expand Down
23 changes: 8 additions & 15 deletions thunder/extend/__init__.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,19 @@
from collections.abc import Callable, Sequence, Hashable
import enum
import sys
import os
import itertools
from typing import Any
from collections.abc import Sequence
from collections.abc import Callable
from collections.abc import Hashable
import os
import sys
from types import ModuleType
import warnings
from functools import cache, partial
from typing import Any

import torch.cuda


from thunder.core.utils import check
from thunder.core.symbol import Symbol, BoundSymbol, default_python_printer
from thunder.core.trace import TraceCtx
from thunder.core.proxies import Proxy, TensorProxy, proxy
from thunder.core.baseutils import run_once
from thunder.core.dtypes import to_torch_dtype
from thunder.core.devices import to_torch_device
from thunder.core.dtypes import to_torch_dtype
from thunder.core.proxies import Proxy, TensorProxy, proxy
from thunder.core.pytree import tree_map
from thunder.core.symbol import Symbol, BoundSymbol, default_python_printer
from thunder.core.trace import TraceCtx

__all__ = [
"register_executor",
Expand Down
72 changes: 60 additions & 12 deletions thunder/tests/opinfos.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import itertools
import math
import operator
from collections import namedtuple
from collections.abc import Sequence
from collections.abc import Callable, Generator, Iterable, Sequence
from functools import partial, wraps
import itertools
import math
from numbers import Number
from typing import Union, Optional, Tuple, Any
from collections.abc import Callable
from collections.abc import Generator, Iterable
import operator
from typing import Any

import numpy as np
import pytest
Expand All @@ -23,14 +21,13 @@
import thunder.core.dtypes as datatypes
from thunder.core.dtypes import to_dtype, to_torch_dtype
import thunder.core.prims as prims
import thunder.executors as executors
import thunder.torch as ltorch
from thunder.core.pytree import tree_map
from thunder.core.symbol import Symbol
from thunder.tests.framework import _all_devicetypes, JAX_AVAILABLE, custom_comparator, IS_WINDOWS, version_between
import thunder.executors as executors
from thunder.tests.framework import _all_devicetypes, JAX_AVAILABLE, custom_comparator, IS_WINDOWS
from thunder.tests.make_tensor import make_tensor
import thunder.extend as extend
import thunder.tests.bf16
import thunder.torch as ltorch

#
# Helpful constants and utility functions
Expand Down Expand Up @@ -5491,11 +5488,62 @@ def var_sample_generator(op, device, dtype, requires_grad):
yield SampleInput(make_tensor((), device=device, dtype=dtype, requires_grad=requires_grad))


# glu requires that value of the shape of the input at index dim be even
def glu_sample_generator(op, device, dtype, requires_grad, **kwargs):
make = partial(
beverlylytle marked this conversation as resolved.
Show resolved Hide resolved
make_tensor,
device=device,
dtype=dtype,
requires_grad=requires_grad,
)

cases = (
((4,), None),
((3, 4), 1),
((3, 4), None),
beverlylytle marked this conversation as resolved.
Show resolved Hide resolved
((3, 4), -1),
((4, 3), 0),
((4, 5, 8), None),
((4, 5, 8), 0),
)

for shape, dim in cases:
if dim is None:
yield SampleInput(make(*shape))
else:
yield SampleInput(make(*shape), dim)


def glu_error_generator(op, device, **kwargs):
make = partial(
make_tensor,
device=device,
dtype=torch.float,
)
err_msg = r"Halving dimension must be even, but dimension .* is size .*"
# The value of the shape of the input in the default (last) dim is odd, which is unsupported.
beverlylytle marked this conversation as resolved.
Show resolved Hide resolved
yield (SampleInput(make((3,))), RuntimeError, err_msg)
yield (SampleInput(make((2, 2, 3))), RuntimeError, err_msg)
# The value of the shape of the input at index dim=1 is odd, which is unsupported.
yield (SampleInput(make((4, 5, 8)), dim=1), RuntimeError, err_msg)


glu_opinfo = OpInfo(
ltorch.glu,
sample_input_generator=glu_sample_generator,
error_input_generator=glu_error_generator,
dtypes=(datatypes.inexact,),
torch_reference=torch.nn.functional.glu,
test_directives=(),
)
reduction_ops.append(glu_opinfo)

beverlylytle marked this conversation as resolved.
Show resolved Hide resolved

mean_opinfo = OpInfo(
ltorch.mean,
sample_input_generator=reduction_sample_generator,
torch_reference=torch.mean,
dtypes=(datatypes.floating, datatypes.complexfloating),
dtypes=(datatypes.inexact,),
test_directives=(
# PyTorch doesn't support CPU and CUDA complex half mean
DecorateInfo(
Expand Down
17 changes: 15 additions & 2 deletions thunder/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1345,7 +1345,7 @@ def view_as(a: TensorLike, b: TensorLike, /) -> TensorLike:


#
# Elementwise unary operaitons
# Elementwise unary operations
beverlylytle marked this conversation as resolved.
Show resolved Hide resolved
beverlylytle marked this conversation as resolved.
Show resolved Hide resolved
#
# TODO Add type annotations

Expand Down Expand Up @@ -2687,8 +2687,21 @@ def clone(a: TensorProxy, *, memory_format=torch.preserve_format) -> TensorProxy
register_method("clone", clone)


@torchsymbol(torch.nn.functional.glu, is_method=False)
def glu(a: TensorProxy, /, dim: int = -1) -> TensorProxy:
dim = utils.canonicalize_dim(len(a.shape), dim)
utils.check(
a.shape[dim] % 2 == 0,
lambda: f"Halving dimension must be even, but dimension {dim} is size {a.shape[dim]}",
)
chunk_size = a.shape[dim] // 2
left, right = split(a, (chunk_size, chunk_size), dim=dim)
beverlylytle marked this conversation as resolved.
Show resolved Hide resolved
out = left * sigmoid(right)
return out


@torchsymbol(torch.mean, is_method=True)
def mean(a: TensorProxy, /, dim=None, keepdim: bool = False, *, dtype=None):
def mean(a: TensorProxy, /, dim=None, keepdim: bool = False, *, dtype=None) -> TensorProxy:
beverlylytle marked this conversation as resolved.
Show resolved Hide resolved
dtype = dtype if dtype is not None else a.dtype
utils.check(
not utils.is_integer_dtype(dtype) and not utils.is_boolean_dtype(dtype),
Expand Down
1 change: 0 additions & 1 deletion thunder/torch/default_torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,6 @@
torch.nn.functional.fractional_max_pool3d,
torch.nn.functional.fractional_max_pool3d_with_indices,
torch.nn.functional.gaussian_nll_loss,
torch.nn.functional.glu,
torch.nn.functional.grid_sample,
torch.nn.functional.gumbel_softmax,
torch.nn.functional.hardshrink,
Expand Down
Loading