Add glu op (#1477)
beverlylytle authored Dec 3, 2024
1 parent 29adb08 commit 2228304
Showing 5 changed files with 83 additions and 31 deletions.
1 change: 0 additions & 1 deletion thunder/executors/torchex.py
@@ -32,7 +32,6 @@
import thunder.core.utils as utils

import thunder.torch as ltorch
from thunder.torch import DeviceLike, dtypeLike, TensorLike

from thunder.extend import OperatorExecutor, register_executor, add_always_executor

23 changes: 8 additions & 15 deletions thunder/extend/__init__.py
@@ -1,26 +1,19 @@
from collections.abc import Callable, Sequence, Hashable
import enum
import sys
import os
import itertools
from typing import Any
from collections.abc import Sequence
from collections.abc import Callable
from collections.abc import Hashable
import os
import sys
from types import ModuleType
import warnings
from functools import cache, partial
from typing import Any

import torch.cuda


from thunder.core.utils import check
from thunder.core.symbol import Symbol, BoundSymbol, default_python_printer
from thunder.core.trace import TraceCtx
from thunder.core.proxies import Proxy, TensorProxy, proxy
from thunder.core.baseutils import run_once
from thunder.core.dtypes import to_torch_dtype
from thunder.core.devices import to_torch_device
from thunder.core.dtypes import to_torch_dtype
from thunder.core.proxies import Proxy, TensorProxy, proxy
from thunder.core.pytree import tree_map
from thunder.core.symbol import Symbol, BoundSymbol, default_python_printer
from thunder.core.trace import TraceCtx

__all__ = [
"register_executor",
72 changes: 60 additions & 12 deletions thunder/tests/opinfos.py
@@ -1,13 +1,11 @@
import itertools
import math
import operator
from collections import namedtuple
from collections.abc import Sequence
from collections.abc import Callable, Generator, Iterable, Sequence
from functools import partial, wraps
import itertools
import math
from numbers import Number
from typing import Union, Optional, Tuple, Any
from collections.abc import Callable
from collections.abc import Generator, Iterable
import operator
from typing import Any

import numpy as np
import pytest
@@ -23,14 +21,13 @@
import thunder.core.dtypes as datatypes
from thunder.core.dtypes import to_dtype, to_torch_dtype
import thunder.core.prims as prims
import thunder.executors as executors
import thunder.torch as ltorch
from thunder.core.pytree import tree_map
from thunder.core.symbol import Symbol
from thunder.tests.framework import _all_devicetypes, JAX_AVAILABLE, custom_comparator, IS_WINDOWS, version_between
import thunder.executors as executors
from thunder.tests.framework import _all_devicetypes, JAX_AVAILABLE, custom_comparator, IS_WINDOWS
from thunder.tests.make_tensor import make_tensor
import thunder.extend as extend
import thunder.tests.bf16
import thunder.torch as ltorch

#
# Helpful constants and utility functions
@@ -5491,11 +5488,62 @@ def var_sample_generator(op, device, dtype, requires_grad):
yield SampleInput(make_tensor((), device=device, dtype=dtype, requires_grad=requires_grad))


# glu requires that the size of the input along the given dim be even
def glu_sample_generator(op, device, dtype, requires_grad, **kwargs):
make = partial(
make_tensor,
device=device,
dtype=dtype,
requires_grad=requires_grad,
)

cases = (
((4,), None),
((3, 4), 1),
((3, 4), None),
((3, 4), -1),
((4, 3), 0),
((4, 5, 8), None),
((4, 5, 8), 0),
)

for shape, dim in cases:
if dim is None:
yield SampleInput(make(*shape))
else:
yield SampleInput(make(*shape), dim)
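
# Illustrative sketch (editorial addition, not part of this commit): glu halves
# the chosen dimension, so the (3, 4) case with dim=1 above produces a (3, 2) output.
def _glu_shape_sketch(device="cpu"):
    out = torch.nn.functional.glu(make_tensor((3, 4), device=device, dtype=torch.float), dim=1)
    assert out.shape == (3, 2)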


def glu_error_generator(op, device, **kwargs):
make = partial(
make_tensor,
device=device,
dtype=torch.float,
)
err_msg = r"Halving dimension must be even, but dimension .* is size .*"
# The size of the input along the default (last) dim is odd, which is unsupported.
yield (SampleInput(make((3,))), RuntimeError, err_msg)
yield (SampleInput(make((2, 2, 3))), RuntimeError, err_msg)
# The size of the input along dim=1 is odd, which is unsupported.
yield (SampleInput(make((4, 5, 8)), dim=1), RuntimeError, err_msg)
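
# Illustrative sketch (editorial addition, not part of this commit): the error
# cases above mirror PyTorch's own behavior, which can be checked directly.
def _glu_error_sanity_check_sketch(device="cpu"):
    x = make_tensor((2, 2, 3), device=device, dtype=torch.float)
    # The last dim has size 3 (odd), so PyTorch refuses to halve it.
    with pytest.raises(RuntimeError, match="Halving dimension must be even"):
        torch.nn.functional.glu(x)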


glu_opinfo = OpInfo(
ltorch.glu,
sample_input_generator=glu_sample_generator,
error_input_generator=glu_error_generator,
dtypes=(datatypes.inexact,),
torch_reference=torch.nn.functional.glu,
test_directives=(),
)
reduction_ops.append(glu_opinfo)
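
# Illustrative sketch (editorial addition, not part of this commit): with the
# symbol and OpInfo in place, the op can be exercised end to end; this assumes
# the usual thunder.jit entry point and compares against the torch reference.
def _glu_end_to_end_sketch(device="cpu"):
    import thunder

    x = make_tensor((3, 4), device=device, dtype=torch.float)
    jfn = thunder.jit(lambda t: torch.nn.functional.glu(t, dim=1))
    assert torch.allclose(jfn(x), torch.nn.functional.glu(x, dim=1))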


mean_opinfo = OpInfo(
ltorch.mean,
sample_input_generator=reduction_sample_generator,
torch_reference=torch.mean,
dtypes=(datatypes.floating, datatypes.complexfloating),
dtypes=(datatypes.inexact,),
test_directives=(
# PyTorch doesn't support CPU and CUDA complex half mean
DecorateInfo(
17 changes: 15 additions & 2 deletions thunder/torch/__init__.py
@@ -1345,7 +1345,7 @@ def view_as(a: TensorLike, b: TensorLike, /) -> TensorLike:


#
# Elementwise unary operaitons
# Elementwise unary operations
#
# TODO Add type annotations

@@ -2687,8 +2687,21 @@ def clone(a: TensorProxy, *, memory_format=torch.preserve_format) -> TensorProxy
register_method("clone", clone)


@torchsymbol(torch.nn.functional.glu, is_method=False)
def glu(a: TensorProxy, /, dim: int = -1) -> TensorProxy:
dim = utils.canonicalize_dim(len(a.shape), dim)
utils.check(
a.shape[dim] % 2 == 0,
lambda: f"Halving dimension must be even, but dimension {dim} is size {a.shape[dim]}",
)
chunk_size = a.shape[dim] // 2
left, right = split(a, (chunk_size, chunk_size), dim=dim)
out = left * sigmoid(right)
return out
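
# Illustrative sketch (editorial addition, not part of this commit): the
# decomposition above (split in half along dim, then gate with sigmoid) matches
# PyTorch's reference behavior.
def _glu_reference_check_sketch():
    x = torch.randn(3, 8)
    left, right = torch.chunk(x, 2, dim=-1)
    assert torch.allclose(torch.nn.functional.glu(x, dim=-1), left * torch.sigmoid(right))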


@torchsymbol(torch.mean, is_method=True)
def mean(a: TensorProxy, /, dim=None, keepdim: bool = False, *, dtype=None):
def mean(a: TensorProxy, /, dim=None, keepdim: bool = False, *, dtype=None) -> TensorProxy:
dtype = dtype if dtype is not None else a.dtype
utils.check(
not utils.is_integer_dtype(dtype) and not utils.is_boolean_dtype(dtype),
1 change: 0 additions & 1 deletion thunder/torch/default_torch_ops.py
@@ -346,7 +346,6 @@
torch.nn.functional.fractional_max_pool3d,
torch.nn.functional.fractional_max_pool3d_with_indices,
torch.nn.functional.gaussian_nll_loss,
torch.nn.functional.glu,
torch.nn.functional.grid_sample,
torch.nn.functional.gumbel_softmax,
torch.nn.functional.hardshrink,
