diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index 9385eb3970..0ee80f9cf3 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -32,7 +32,6 @@
 import thunder.core.utils as utils
 import thunder.torch as ltorch
 
-from thunder.torch import DeviceLike, dtypeLike, TensorLike
 from thunder.extend import OperatorExecutor, register_executor, add_always_executor
 
 
diff --git a/thunder/extend/__init__.py b/thunder/extend/__init__.py
index 73181415c6..aca7cbdad7 100644
--- a/thunder/extend/__init__.py
+++ b/thunder/extend/__init__.py
@@ -1,26 +1,19 @@
+from collections.abc import Callable, Sequence, Hashable
 import enum
-import sys
-import os
 import itertools
-from typing import Any
-from collections.abc import Sequence
-from collections.abc import Callable
-from collections.abc import Hashable
+import os
+import sys
 from types import ModuleType
-import warnings
-from functools import cache, partial
+from typing import Any
 
 import torch.cuda
-
-from thunder.core.utils import check
-from thunder.core.symbol import Symbol, BoundSymbol, default_python_printer
-from thunder.core.trace import TraceCtx
-from thunder.core.proxies import Proxy, TensorProxy, proxy
-from thunder.core.baseutils import run_once
-from thunder.core.dtypes import to_torch_dtype
 from thunder.core.devices import to_torch_device
+from thunder.core.dtypes import to_torch_dtype
+from thunder.core.proxies import Proxy, TensorProxy, proxy
 from thunder.core.pytree import tree_map
+from thunder.core.symbol import Symbol, BoundSymbol, default_python_printer
+from thunder.core.trace import TraceCtx
 
 
 __all__ = [
     "register_executor",
diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py
index 5cc5c89077..159e1ce3d9 100644
--- a/thunder/tests/opinfos.py
+++ b/thunder/tests/opinfos.py
@@ -1,13 +1,11 @@
-import itertools
-import math
-import operator
 from collections import namedtuple
-from collections.abc import Sequence
+from collections.abc import Callable, Generator, Iterable, Sequence
 from functools import partial, wraps
+import itertools
+import math
 from numbers import Number
-from typing import Union, Optional, Tuple, Any
-from collections.abc import Callable
-from collections.abc import Generator, Iterable
+import operator
+from typing import Any
 
 import numpy as np
 import pytest
@@ -23,14 +21,13 @@
 import thunder.core.dtypes as datatypes
 from thunder.core.dtypes import to_dtype, to_torch_dtype
 import thunder.core.prims as prims
-import thunder.executors as executors
-import thunder.torch as ltorch
 from thunder.core.pytree import tree_map
 from thunder.core.symbol import Symbol
-from thunder.tests.framework import _all_devicetypes, JAX_AVAILABLE, custom_comparator, IS_WINDOWS, version_between
+import thunder.executors as executors
+from thunder.tests.framework import _all_devicetypes, JAX_AVAILABLE, custom_comparator, IS_WINDOWS
 from thunder.tests.make_tensor import make_tensor
-import thunder.extend as extend
 import thunder.tests.bf16
+import thunder.torch as ltorch
 
 #
 # Helpful constants and utility functions
@@ -5491,11 +5488,62 @@ def var_sample_generator(op, device, dtype, requires_grad):
     yield SampleInput(make_tensor((), device=device, dtype=dtype, requires_grad=requires_grad))
 
 
+# glu requires that the size of the input at index dim be even
+def glu_sample_generator(op, device, dtype, requires_grad, **kwargs):
+    make = partial(
+        make_tensor,
+        device=device,
+        dtype=dtype,
+        requires_grad=requires_grad,
+    )
+
+    cases = (
+        ((4,), None),
+        ((3, 4), 1),
+        ((3, 4), None),
+        ((3, 4), -1),
+        ((4, 3), 0),
+        ((4, 5, 8), None),
+        ((4, 5, 8), 0),
+    )
+
+    for shape, dim in cases:
+        if dim is None:
+            yield SampleInput(make(*shape))
+        else:
+            yield SampleInput(make(*shape), dim)
+
+
+def glu_error_generator(op, device, **kwargs):
+    make = partial(
+        make_tensor,
+        device=device,
+        dtype=torch.float,
+    )
+    err_msg = r"Halving dimension must be even, but dimension .* is size .*"
+    # The size of the input in the default (last) dim is odd, which is unsupported.
+    yield (SampleInput(make((3,))), RuntimeError, err_msg)
+    yield (SampleInput(make((2, 2, 3))), RuntimeError, err_msg)
+    # The size of the input at index dim=1 is odd, which is unsupported.
+    yield (SampleInput(make((4, 5, 8)), dim=1), RuntimeError, err_msg)
+
+
+glu_opinfo = OpInfo(
+    ltorch.glu,
+    sample_input_generator=glu_sample_generator,
+    error_input_generator=glu_error_generator,
+    dtypes=(datatypes.inexact,),
+    torch_reference=torch.nn.functional.glu,
+    test_directives=(),
+)
+reduction_ops.append(glu_opinfo)
+
+
 mean_opinfo = OpInfo(
     ltorch.mean,
     sample_input_generator=reduction_sample_generator,
     torch_reference=torch.mean,
-    dtypes=(datatypes.floating, datatypes.complexfloating),
+    dtypes=(datatypes.inexact,),
     test_directives=(
         # PyTorch doesn't support CPU and CUDA complex half mean
         DecorateInfo(
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 6327b4d05c..1dd272ee9f 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -1345,7 +1345,7 @@ def view_as(a: TensorLike, b: TensorLike, /) -> TensorLike:
 
 
 #
-# Elementwise unary operaitons
+# Elementwise unary operations
 #
 
 # TODO Add type annotations
@@ -2687,8 +2687,21 @@ def clone(a: TensorProxy, *, memory_format=torch.preserve_format) -> TensorProxy
 register_method("clone", clone)
 
 
+@torchsymbol(torch.nn.functional.glu, is_method=False)
+def glu(a: TensorProxy, /, dim: int = -1) -> TensorProxy:
+    dim = utils.canonicalize_dim(len(a.shape), dim)
+    utils.check(
+        a.shape[dim] % 2 == 0,
+        lambda: f"Halving dimension must be even, but dimension {dim} is size {a.shape[dim]}",
+    )
+    chunk_size = a.shape[dim] // 2
+    left, right = split(a, (chunk_size, chunk_size), dim=dim)
+    out = left * sigmoid(right)
+    return out
+
+
 @torchsymbol(torch.mean, is_method=True)
-def mean(a: TensorProxy, /, dim=None, keepdim: bool = False, *, dtype=None):
+def mean(a: TensorProxy, /, dim=None, keepdim: bool = False, *, dtype=None) -> TensorProxy:
     dtype = dtype if dtype is not None else a.dtype
     utils.check(
         not utils.is_integer_dtype(dtype) and not utils.is_boolean_dtype(dtype),
diff --git a/thunder/torch/default_torch_ops.py b/thunder/torch/default_torch_ops.py
index 91ea98adf0..a43b6c3131 100644
--- a/thunder/torch/default_torch_ops.py
+++ b/thunder/torch/default_torch_ops.py
@@ -346,7 +346,6 @@
     torch.nn.functional.fractional_max_pool3d,
     torch.nn.functional.fractional_max_pool3d_with_indices,
     torch.nn.functional.gaussian_nll_loss,
-    torch.nn.functional.glu,
     torch.nn.functional.grid_sample,
     torch.nn.functional.gumbel_softmax,
     torch.nn.functional.hardshrink,
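
As a quick sanity check of the decomposition added in `thunder/torch/__init__.py` above (this sketch is not part of the patch), splitting the input in half along `dim` and gating the first half with the sigmoid of the second matches `torch.nn.functional.glu` in plain PyTorch:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
a = torch.randn(4, 6)

# Same decomposition the thunder.torch.glu symbol uses: split in half along
# the last dim, then gate the first half with the sigmoid of the second half.
chunk_size = a.shape[-1] // 2
left, right = torch.split(a, (chunk_size, chunk_size), dim=-1)
expected = left * torch.sigmoid(right)

assert torch.allclose(F.glu(a, dim=-1), expected)
```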