From cce2d2e54d03f4be22057e59bc5ee7cba5ff2206 Mon Sep 17 00:00:00 2001 From: apaz-cli Date: Sun, 24 Mar 2024 23:15:52 -0400 Subject: [PATCH] Fix .mT, add tests. --- thunder/clang/__init__.py | 17 +++++++++++++++++ thunder/tests/opinfos.py | 33 +++++++++++++++++++++++++-------- 2 files changed, 42 insertions(+), 8 deletions(-) diff --git a/thunder/clang/__init__.py b/thunder/clang/__init__.py index 3fd192802f..e1bdd0bebb 100644 --- a/thunder/clang/__init__.py +++ b/thunder/clang/__init__.py @@ -9,7 +9,9 @@ from types import EllipsisType, NoneType import copy import time +import warnings +from thunder.core.baseutils import run_once from thunder.core.compile_data import using_symbolic_values from thunder.clang.langctx import register_method from thunder.core.langctxs import langctx, Languages @@ -1160,6 +1162,14 @@ def compute_broadcast_shape(*_shapes): return tuple(common_shape) +@run_once +def mT_scalar_warning(): + warnings.warn( + "Tensor.mT is deprecated on 0-D tensors. This function is the identity in these cases.", + UserWarning, + ) + + @clangop(method_name="mT") def matrix_transpose(a: TensorProxy) -> TensorProxy: """Transposes the last two dimensions of a tensor. @@ -1181,6 +1191,13 @@ def matrix_transpose(a: TensorProxy) -> TensorProxy: [2, 5], [3, 6]]) """ + + if a.ndim == 0: + mT_scalar_warning() + return a + elif a.ndim == 1: + raise RuntimeError(f"tensor.mT is only supported on matrices or batches of matrices. Got 1-D tensor.") + dim0, dim1 = -2, -1 dim0, dim1 = utils.canonicalize_dims(a.ndim, (dim0, dim1)) permutation = list(range(a.ndim)) diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py index 17af749d37..0d5d2cb0ab 100644 --- a/thunder/tests/opinfos.py +++ b/thunder/tests/opinfos.py @@ -3975,17 +3975,30 @@ def matrix_transpose_sample_generator(op, device, dtype, requires_grad, **kwargs # shape cases = ( - (4, 7, 8), - (4, 7), + (), + (2, 3), + (2, 3, 4), + (2, 3, 4, 2), ) for shape in cases: yield SampleInput(make(shape)) +def matrix_transpose_error_generator(op, device, dtype=torch.float32, **kwargs): + make = partial(make_tensor, device=device, dtype=dtype) + + # shape, error type, error message + cases = (((3), RuntimeError, "tensor.mT is only supported on matrices or batches of matrices. Got 1-D tensor."),) + + for shape, err_type, err_msg in cases: + yield SampleInput(make(shape)), err_type, err_msg + + transpose_opinfo = OpInfo( clang.matrix_transpose, sample_input_generator=matrix_transpose_sample_generator, + error_input_generator=matrix_transpose_error_generator, torch_reference=lambda x: x.mT, ) shape_ops.append(transpose_opinfo) @@ -6824,9 +6837,11 @@ def cross_entropy_reference_generator(op, device, dtype, requires_grad, **kwargs C = input_shape[1] if len(input_shape) >= 2 else input_shape[0] yield SampleInput( make(shape[0]), - make(shape[1], low=0, high=C, dtype=torch.long, requires_grad=False) - if not probability_target - else make(shape[1], low=0.0, high=1.0, requires_grad=True), + ( + make(shape[1], low=0, high=C, dtype=torch.long, requires_grad=False) + if not probability_target + else make(shape[1], low=0.0, high=1.0, requires_grad=True) + ), weight=make(C) if weight_flag else None, ignore_index=ignore_index, reduction=reduction_str, @@ -6870,9 +6885,11 @@ def cross_entropy_sample_generator(op, device, dtype, requires_grad, **kwargs): C = input_shape[1] if len(input_shape) >= 2 else input_shape[0] yield SampleInput( make(shape[0]), - make(shape[1], low=0, high=C, dtype=torch.long, requires_grad=False) - if not probability_target - else make(shape[1], low=0.0, high=1.0, requires_grad=True), + ( + make(shape[1], low=0, high=C, dtype=torch.long, requires_grad=False) + if not probability_target + else make(shape[1], low=0.0, high=1.0, requires_grad=True) + ), weight=make(C) if weight_flag else None, ignore_index=ignore_index, reduction=reduction_str,