Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add glu op #1477

Merged
merged 12 commits into from
Dec 3, 2024
23 changes: 8 additions & 15 deletions thunder/extend/__init__.py
beverlylytle marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,26 +1,19 @@
from collections.abc import Callable, Sequence, Hashable
import enum
import sys
import os
import itertools
from typing import Any
from collections.abc import Sequence
from collections.abc import Callable
from collections.abc import Hashable
import os
import sys
from types import ModuleType
import warnings
from functools import cache, partial
from typing import Any

import torch.cuda


from thunder.core.utils import check
from thunder.core.symbol import Symbol, BoundSymbol, default_python_printer
from thunder.core.trace import TraceCtx
from thunder.core.proxies import Proxy, TensorProxy, proxy
from thunder.core.baseutils import run_once
from thunder.core.dtypes import to_torch_dtype
from thunder.core.devices import to_torch_device
from thunder.core.dtypes import to_torch_dtype
from thunder.core.proxies import Proxy, TensorProxy, proxy
from thunder.core.pytree import tree_map
from thunder.core.symbol import Symbol, BoundSymbol, default_python_printer
from thunder.core.trace import TraceCtx

__all__ = [
"register_executor",
Expand Down
67 changes: 56 additions & 11 deletions thunder/tests/opinfos.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import itertools
import math
import operator
from collections import namedtuple
from collections.abc import Sequence
from collections.abc import Callable, Generator, Iterable, Sequence
from functools import partial, wraps
import itertools
import math
from numbers import Number
from typing import Union, Optional, Tuple, Any
from collections.abc import Callable
from collections.abc import Generator, Iterable
import operator
from typing import Any

import numpy as np
import pytest
Expand All @@ -23,14 +21,13 @@
import thunder.core.dtypes as datatypes
from thunder.core.dtypes import to_dtype, to_torch_dtype
import thunder.core.prims as prims
import thunder.executors as executors
import thunder.torch as ltorch
from thunder.core.pytree import tree_map
from thunder.core.symbol import Symbol
from thunder.tests.framework import _all_devicetypes, JAX_AVAILABLE, custom_comparator, IS_WINDOWS, version_between
import thunder.executors as executors
from thunder.tests.framework import _all_devicetypes, JAX_AVAILABLE, custom_comparator, IS_WINDOWS
from thunder.tests.make_tensor import make_tensor
import thunder.extend as extend
import thunder.tests.bf16
import thunder.torch as ltorch

#
# Helpful constants and utility functions
Expand Down Expand Up @@ -5491,6 +5488,54 @@ def var_sample_generator(op, device, dtype, requires_grad):
yield SampleInput(make_tensor((), device=device, dtype=dtype, requires_grad=requires_grad))


def glu_sample_generator(op, device, dtype, requires_grad, **kwargs):
make = partial(
beverlylytle marked this conversation as resolved.
Show resolved Hide resolved
make_tensor,
device=device,
dtype=dtype,
requires_grad=requires_grad,
)

cases = (
((4,),),
((3, 4), 1),
((3, 4),),
beverlylytle marked this conversation as resolved.
Show resolved Hide resolved
((4, 3), 0),
((4, 5, 8),),
((4, 5, 8), 0),
)

for case in cases:
if len(case) == 1:
beverlylytle marked this conversation as resolved.
Show resolved Hide resolved
yield SampleInput(make(*case))
else:
shape, dim = case
yield SampleInput(make(shape), dim)


def glu_error_generator(op, device, **kwargs):
make = partial(
make_tensor,
device=device,
dtype=torch.float,
)
err_msg = r"Halving dimension must be even, but dimension .* is size .*"
yield (SampleInput(make((3,))), RuntimeError, err_msg)
beverlylytle marked this conversation as resolved.
Show resolved Hide resolved
yield (SampleInput(make((2, 2, 3))), RuntimeError, err_msg)
yield (SampleInput(make((4, 5, 8)), dim=1), RuntimeError, err_msg)


glu_opinfo = OpInfo(
ltorch.glu,
sample_input_generator=glu_sample_generator,
error_input_generator=glu_error_generator,
dtypes=(datatypes.floating, datatypes.complexfloating),
torch_reference=torch.nn.functional.glu,
beverlylytle marked this conversation as resolved.
Show resolved Hide resolved
test_directives=(),
)
reduction_ops.append(glu_opinfo)

beverlylytle marked this conversation as resolved.
Show resolved Hide resolved

mean_opinfo = OpInfo(
ltorch.mean,
sample_input_generator=reduction_sample_generator,
Expand Down
15 changes: 14 additions & 1 deletion thunder/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1345,7 +1345,7 @@ def view_as(a: TensorLike, b: TensorLike, /) -> TensorLike:


#
# Elementwise unary operaitons
# Elementwise unary operations
beverlylytle marked this conversation as resolved.
Show resolved Hide resolved
beverlylytle marked this conversation as resolved.
Show resolved Hide resolved
#
# TODO Add type annotations

Expand Down Expand Up @@ -2687,6 +2687,19 @@ def clone(a: TensorProxy, *, memory_format=torch.preserve_format) -> TensorProxy
register_method("clone", clone)


@torchsymbol(torch.nn.functional.glu, is_method=False)
def glu(a: TensorProxy, /, dim: int = -1):
beverlylytle marked this conversation as resolved.
Show resolved Hide resolved
dim = -1 if dim is None else dim
crcrpar marked this conversation as resolved.
Show resolved Hide resolved
utils.check(
a.shape[dim] % 2 == 0,
lambda: f"Halving dimension must be even, but dimension {dim} is size {a.shape[dim]}",
)
chunk_size = a.shape[dim] // 2
left, right = split(a, (chunk_size, chunk_size), dim=dim)
beverlylytle marked this conversation as resolved.
Show resolved Hide resolved
out = left * sigmoid(right)
return out


@torchsymbol(torch.mean, is_method=True)
def mean(a: TensorProxy, /, dim=None, keepdim: bool = False, *, dtype=None):
dtype = dtype if dtype is not None else a.dtype
Expand Down
1 change: 0 additions & 1 deletion thunder/torch/default_torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,6 @@
torch.nn.functional.fractional_max_pool3d,
torch.nn.functional.fractional_max_pool3d_with_indices,
torch.nn.functional.gaussian_nll_loss,
torch.nn.functional.glu,
torch.nn.functional.grid_sample,
torch.nn.functional.gumbel_softmax,
torch.nn.functional.hardshrink,
Expand Down
Loading