From c95487deaeeb1b5133075d2b0c8cb02b5c1cd10d Mon Sep 17 00:00:00 2001
From: beverlylytle
Date: Thu, 21 Nov 2024 13:48:48 +0200
Subject: [PATCH 1/9] add glu op

---
 thunder/tests/opinfos.py           | 48 ++++++++++++++++++++++++++++++
 thunder/torch/__init__.py          | 15 +++++++++-
 thunder/torch/default_torch_ops.py |  1 -
 3 files changed, 62 insertions(+), 2 deletions(-)

diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py
index b8daca7de2..18c47430c1 100644
--- a/thunder/tests/opinfos.py
+++ b/thunder/tests/opinfos.py
@@ -5475,6 +5475,54 @@ def var_sample_generator(op, device, dtype, requires_grad):
     yield SampleInput(make_tensor((), device=device, dtype=dtype, requires_grad=requires_grad))
 
 
+def glu_sample_generator(op, device, dtype, requires_grad, **kwargs):
+    make = partial(
+        make_tensor,
+        device=device,
+        dtype=dtype,
+        requires_grad=requires_grad,
+    )
+
+    cases = (
+        ((4,),),
+        ((3, 4), 1),
+        ((3, 4),),
+        ((4, 3), 0),
+        ((4, 5, 8),),
+        ((4, 5, 8), 0),
+    )
+
+    for case in cases:
+        if len(case) == 1:
+            yield SampleInput(make(*case))
+        else:
+            shape, dim = case
+            yield SampleInput(make(shape), dim)
+
+
+def glu_error_generator(op, device, **kwargs):
+    make = partial(
+        make_tensor,
+        device=device,
+        dtype=torch.float,
+    )
+    err_msg = r"Halving dimension must be even, but dimension .* is size .*"
+    yield (SampleInput(make((3,))), RuntimeError, err_msg)
+    yield (SampleInput(make((2, 2, 3))), RuntimeError, err_msg)
+    yield (SampleInput(make((4, 5, 8)), dim=1), RuntimeError, err_msg)
+
+
+glu_opinfo = OpInfo(
+    ltorch.glu,
+    sample_input_generator=glu_sample_generator,
+    error_input_generator=glu_error_generator,
+    dtypes=(datatypes.floating, datatypes.complexfloating),
+    torch_reference=torch.nn.functional.glu,
+    test_directives=(),
+)
+reduction_ops.append(glu_opinfo)
+
+
 mean_opinfo = OpInfo(
     ltorch.mean,
     sample_input_generator=reduction_sample_generator,
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index a94ada1cc8..168300f838 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -1345,7 +1345,7 @@ def view_as(a: TensorLike, b: TensorLike, /) -> TensorLike:
 
 
 #
-# Elementwise unary operaitons
+# Elementwise unary operatons
 #
 
 # TODO Add type annotations
@@ -2676,6 +2676,19 @@ def clone(a: TensorProxy, *, memory_format=torch.preserve_format) -> TensorProxy
 register_method("clone", clone)
 
 
+@torchsymbol(torch.nn.functional.glu, is_method=False)
+def glu(a: TensorProxy, /, dim=None):
+    dim = -1 if dim is None else dim
+    utils.check(
+        a.shape[dim] % 2 == 0,
+        lambda: f"Halving dimension must be even, but dimension {dim} is size {a.shape[dim]}",
+    )
+    chunk_size = a.shape[dim] // 2
+    left, right = split(a, (chunk_size, chunk_size), dim=dim)
+    out = left * sigmoid(right)
+    return out
+
+
 @torchsymbol(torch.mean, is_method=True)
 def mean(a: TensorProxy, /, dim=None, keepdim: bool = False, *, dtype=None):
     dtype = dtype if dtype is not None else a.dtype
diff --git a/thunder/torch/default_torch_ops.py b/thunder/torch/default_torch_ops.py
index 84e0ae0f90..163680b9cb 100644
--- a/thunder/torch/default_torch_ops.py
+++ b/thunder/torch/default_torch_ops.py
@@ -346,7 +346,6 @@
     torch.nn.functional.fractional_max_pool3d,
     torch.nn.functional.fractional_max_pool3d_with_indices,
     torch.nn.functional.gaussian_nll_loss,
-    torch.nn.functional.glu,
     torch.nn.functional.grid_sample,
     torch.nn.functional.gumbel_softmax,
     torch.nn.functional.hardshrink,
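
For reference, patch 1's decomposition matches the GLU definition: the input is halved along dim, and the first half is gated by the sigmoid of the second half. A minimal standalone check against eager PyTorch (illustrative only; the input shape is arbitrary):

    import torch
    import torch.nn.functional as F

    # F.glu halves the input along `dim` and computes left * sigmoid(right),
    # the same split-and-sigmoid decomposition used by the glu symbol above.
    x = torch.randn(3, 4)
    left, right = x.chunk(2, dim=-1)
    assert torch.allclose(F.glu(x, dim=-1), left * torch.sigmoid(right))
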
From bb549922aafee2dfa89b34c142ede3853f33f731 Mon Sep 17 00:00:00 2001
From: beverlylytle
Date: Thu, 21 Nov 2024 14:01:29 +0200
Subject: [PATCH 2/9] tidy imports

---
 thunder/tests/opinfos.py | 19 ++++++++-----------
 1 file changed, 8 insertions(+), 11 deletions(-)

diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py
index 18c47430c1..66f339ce4c 100644
--- a/thunder/tests/opinfos.py
+++ b/thunder/tests/opinfos.py
@@ -1,13 +1,11 @@
-import itertools
-import math
-import operator
 from collections import namedtuple
-from collections.abc import Sequence
+from collections.abc import Callable, Generator, Iterable, Sequence
 from functools import partial, wraps
+import itertools
+import math
 from numbers import Number
-from typing import Union, Optional, Tuple, Any
-from collections.abc import Callable
-from collections.abc import Generator, Iterable
+import operator
+from typing import Any
 
 import numpy as np
 import pytest
@@ -23,14 +21,13 @@
 import thunder.core.dtypes as datatypes
 from thunder.core.dtypes import to_dtype, to_torch_dtype
 import thunder.core.prims as prims
-import thunder.executors as executors
-import thunder.torch as ltorch
 from thunder.core.pytree import tree_map
 from thunder.core.symbol import Symbol
-from thunder.tests.framework import _all_devicetypes, JAX_AVAILABLE, custom_comparator, IS_WINDOWS, version_between
+import thunder.executors as executors
+from thunder.tests.framework import _all_devicetypes, JAX_AVAILABLE, custom_comparator, IS_WINDOWS
 from thunder.tests.make_tensor import make_tensor
-import thunder.extend as extend
 import thunder.tests.bf16
+import thunder.torch as ltorch
 
 #
 # Helpful constants and utility functions

From 48fdb683503e6420df6cb2dd40f0ce03105c957a Mon Sep 17 00:00:00 2001
From: beverlylytle
Date: Thu, 21 Nov 2024 14:06:29 +0200
Subject: [PATCH 3/9] tidy imports

---
 thunder/extend/__init__.py | 24 ++++++++----------------
 1 file changed, 8 insertions(+), 16 deletions(-)

diff --git a/thunder/extend/__init__.py b/thunder/extend/__init__.py
index 73181415c6..ac903564a1 100644
--- a/thunder/extend/__init__.py
+++ b/thunder/extend/__init__.py
@@ -1,26 +1,18 @@
+from collections.abc import Callable, Sequence, Hashable
 import enum
-import sys
-import os
 import itertools
-from typing import Any
-from collections.abc import Sequence
-from collections.abc import Callable
-from collections.abc import Hashable
-from types import ModuleType
-import warnings
-from functools import cache, partial
+import os
+import sys
+from typing import Any, ModuleType
 
 import torch.cuda
 
-
-from thunder.core.utils import check
-from thunder.core.symbol import Symbol, BoundSymbol, default_python_printer
-from thunder.core.trace import TraceCtx
-from thunder.core.proxies import Proxy, TensorProxy, proxy
-from thunder.core.baseutils import run_once
-from thunder.core.dtypes import to_torch_dtype
 from thunder.core.devices import to_torch_device
+from thunder.core.dtypes import to_torch_dtype
+from thunder.core.proxies import Proxy, TensorProxy, proxy
 from thunder.core.pytree import tree_map
+from thunder.core.symbol import Symbol, BoundSymbol, default_python_printer
+from thunder.core.trace import TraceCtx
 
 __all__ = [
     "register_executor",

From d7cb6e7ab1513ee221d273eb8c036b78ec9f92bc Mon Sep 17 00:00:00 2001
From: beverlylytle
Date: Tue, 26 Nov 2024 18:01:36 +0200
Subject: [PATCH 4/9] fix import

---
 thunder/extend/__init__.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/thunder/extend/__init__.py b/thunder/extend/__init__.py
index ac903564a1..aca7cbdad7 100644
--- a/thunder/extend/__init__.py
+++ b/thunder/extend/__init__.py
@@ -3,7 +3,8 @@
 import itertools
 import os
 import sys
-from typing import Any, ModuleType
+from types import ModuleType
+from typing import Any
 
 import torch.cuda
 

From 3536c5d08f7cbd3c5f84bfe0e73a59f05d6ca062 Mon Sep 17 00:00:00 2001
From: beverlylytle
Date: Tue, 26 Nov 2024 18:50:13 +0200
Subject: [PATCH 5/9] update type suggestion

---
 thunder/torch/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 168300f838..bc975e7b6c 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -2677,7 +2677,7 @@ def clone(a: TensorProxy, *, memory_format=torch.preserve_format) -> TensorProxy
 
 
 @torchsymbol(torch.nn.functional.glu, is_method=False)
-def glu(a: TensorProxy, /, dim=None):
+def glu(a: TensorProxy, /, dim: None | int = None):
     dim = -1 if dim is None else dim
     utils.check(
         a.shape[dim] % 2 == 0,

From 15194bb6e20452ae0ea10ec90debc5f9e86fda25 Mon Sep 17 00:00:00 2001
From: beverlylytle
Date: Tue, 26 Nov 2024 18:51:29 +0200
Subject: [PATCH 6/9] update default value

---
 thunder/torch/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index bc975e7b6c..b6cf2ee3ee 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -2677,7 +2677,7 @@ def clone(a: TensorProxy, *, memory_format=torch.preserve_format) -> TensorProxy
 
 
 @torchsymbol(torch.nn.functional.glu, is_method=False)
-def glu(a: TensorProxy, /, dim: None | int = None):
+def glu(a: TensorProxy, /, dim: int = -1):
     dim = -1 if dim is None else dim
     utils.check(
         a.shape[dim] % 2 == 0,

From ac4a073e49cf84a6254b596c3059d44ca8e223f9 Mon Sep 17 00:00:00 2001
From: beverlylytle
Date: Tue, 26 Nov 2024 18:53:21 +0200
Subject: [PATCH 7/9] typo

---
 thunder/torch/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index b6cf2ee3ee..f503481503 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -1345,7 +1345,7 @@ def view_as(a: TensorLike, b: TensorLike, /) -> TensorLike:
 
 
 #
-# Elementwise unary operatons
+# Elementwise unary operations
 #
 
 # TODO Add type annotations

From 8bf0513a7a0cb12d1f17292f8a3f520c18da8758 Mon Sep 17 00:00:00 2001
From: beverlylytle
Date: Wed, 27 Nov 2024 13:28:26 +0200
Subject: [PATCH 8/9] fix dim check

---
 thunder/torch/__init__.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 9b5d80d9ca..c5d2ebc9ba 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -2689,7 +2689,11 @@ def clone(a: TensorProxy, *, memory_format=torch.preserve_format) -> TensorProxy
 
 @torchsymbol(torch.nn.functional.glu, is_method=False)
 def glu(a: TensorProxy, /, dim: int = -1):
-    dim = -1 if dim is None else dim
+    utils.check(
+        -a.ndim <= dim < a.ndim,
+        lambda: f"Dimension out of range (expected to be in range [{-a.ndim}, {a.ndim - 1}], but got {dim})",
+        IndexError,
+    )
     utils.check(
         a.shape[dim] % 2 == 0,
         lambda: f"Halving dimension must be even, but dimension {dim} is size {a.shape[dim]}",
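
Patch 8 matters because, without it, an out-of-range dim would surface as a bare Python indexing error from a.shape[dim]; the explicit check instead raises an IndexError whose message mirrors eager PyTorch. A small demonstration of the behavior being matched (a sketch; PyTorch's exact message wording may differ):

    import torch
    import torch.nn.functional as F

    x = torch.randn(4, 6)
    try:
        F.glu(x, dim=2)  # valid dims for a 2-d tensor are -2..1
    except IndexError as e:
        # Expected to resemble:
        # Dimension out of range (expected to be in range of [-2, 1], but got 2)
        print(e)
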
From 7a6d89a4f2c5d8894fb0ab84cece587171fdada3 Mon Sep 17 00:00:00 2001
From: beverlylytle
Date: Tue, 3 Dec 2024 16:46:43 +0200
Subject: [PATCH 9/9] respond to comments

---
 thunder/executors/torchex.py |  1 -
 thunder/tests/opinfos.py     | 23 +++++++++++++----------
 thunder/torch/__init__.py    | 10 +++-------
 3 files changed, 16 insertions(+), 18 deletions(-)

diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index 9385eb3970..0ee80f9cf3 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -32,7 +32,6 @@
 import thunder.core.utils as utils
 import thunder.torch as ltorch
 
-from thunder.torch import DeviceLike, dtypeLike, TensorLike
 from thunder.extend import OperatorExecutor, register_executor, add_always_executor
 
 
diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py
index c1facfed45..159e1ce3d9 100644
--- a/thunder/tests/opinfos.py
+++ b/thunder/tests/opinfos.py
@@ -5488,6 +5488,7 @@ def var_sample_generator(op, device, dtype, requires_grad):
     yield SampleInput(make_tensor((), device=device, dtype=dtype, requires_grad=requires_grad))
 
 
+# glu requires that value of the shape of the input at index dim be even
 def glu_sample_generator(op, device, dtype, requires_grad, **kwargs):
     make = partial(
         make_tensor,
@@ -5497,20 +5498,20 @@ def glu_sample_generator(op, device, dtype, requires_grad, **kwargs):
     )
 
     cases = (
-        ((4,),),
+        ((4,), None),
         ((3, 4), 1),
-        ((3, 4),),
+        ((3, 4), None),
+        ((3, 4), -1),
         ((4, 3), 0),
-        ((4, 5, 8),),
+        ((4, 5, 8), None),
         ((4, 5, 8), 0),
     )
 
-    for case in cases:
-        if len(case) == 1:
-            yield SampleInput(make(*case))
+    for shape, dim in cases:
+        if dim is None:
+            yield SampleInput(make(*shape))
         else:
-            shape, dim = case
-            yield SampleInput(make(shape), dim)
+            yield SampleInput(make(*shape), dim)
 
 
 def glu_error_generator(op, device, **kwargs):
@@ -5520,8 +5521,10 @@ def glu_error_generator(op, device, **kwargs):
         dtype=torch.float,
     )
     err_msg = r"Halving dimension must be even, but dimension .* is size .*"
+    # The value of the shape of the input in the default (last) dim is odd, which is unsupported.
     yield (SampleInput(make((3,))), RuntimeError, err_msg)
     yield (SampleInput(make((2, 2, 3))), RuntimeError, err_msg)
+    # The value of the shape of the input at index dim=1 is odd, which is unsupported.
    yield (SampleInput(make((4, 5, 8)), dim=1), RuntimeError, err_msg)
 
 
@@ -5529,7 +5532,7 @@ def glu_error_generator(op, device, **kwargs):
     ltorch.glu,
     sample_input_generator=glu_sample_generator,
     error_input_generator=glu_error_generator,
-    dtypes=(datatypes.floating, datatypes.complexfloating),
+    dtypes=(datatypes.inexact,),
     torch_reference=torch.nn.functional.glu,
     test_directives=(),
 )
@@ -5540,7 +5543,7 @@
     ltorch.mean,
     sample_input_generator=reduction_sample_generator,
     torch_reference=torch.mean,
-    dtypes=(datatypes.floating, datatypes.complexfloating),
+    dtypes=(datatypes.inexact,),
     test_directives=(
         # PyTorch doesn't support CPU and CUDA complex half mean
         DecorateInfo(
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index c5d2ebc9ba..87bcf32e9d 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -2688,12 +2688,8 @@ def clone(a: TensorProxy, *, memory_format=torch.preserve_format) -> TensorProxy
 
 
 @torchsymbol(torch.nn.functional.glu, is_method=False)
-def glu(a: TensorProxy, /, dim: int = -1):
-    utils.check(
-        -a.ndim <= dim < a.ndim,
-        lambda: f"Dimension out of range (expected to be in range [{-a.ndim}, {a.ndim - 1}], but got {dim})",
-        IndexError,
-    )
+def glu(a: TensorProxy, /, dim: int = -1) -> TensorProxy:
+    dim = utils.canonicalize_dim(len(a.shape), dim)
     utils.check(
         a.shape[dim] % 2 == 0,
         lambda: f"Halving dimension must be even, but dimension {dim} is size {a.shape[dim]}",
@@ -2705,7 +2701,7 @@ def glu(a: TensorProxy, /, dim: int = -1) -> TensorProxy:
 
 
 @torchsymbol(torch.mean, is_method=True)
-def mean(a: TensorProxy, /, dim=None, keepdim: bool = False, *, dtype=None):
+def mean(a: TensorProxy, /, dim=None, keepdim: bool = False, *, dtype=None) -> TensorProxy:
     dtype = dtype if dtype is not None else a.dtype
     utils.check(
         not utils.is_integer_dtype(dtype) and not utils.is_boolean_dtype(dtype),
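
In its final form, glu delegates dim validation to utils.canonicalize_dim, which bounds-checks dim and maps negative values to non-negative ones, so the subsequent a.shape[dim] and split calls see a canonical index. A standalone sketch of what such a helper conventionally does (a hypothetical reimplementation for illustration, not thunder's actual code):

    def canonicalize_dim(rank: int, dim: int) -> int:
        # Reject out-of-range dims, then wrap negative indices.
        if not -rank <= dim < rank:
            raise IndexError(f"Dimension out of range (expected to be in range [{-rank}, {rank - 1}], but got {dim})")
        return dim + rank if dim < 0 else dim

    assert canonicalize_dim(3, -1) == 2  # dim=-1 on a rank-3 tensor wraps to 2

The dtype change in the same patch, from (datatypes.floating, datatypes.complexfloating) to (datatypes.inexact,), reads as an equivalent spelling: inexact presumably subsumes both the floating and complex-floating dtype kinds.
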