Skip to content

Commit

Permalink
Fix .mT, add tests (#68)
Browse files Browse the repository at this point in the history
  • Loading branch information
apaz-cli authored Mar 25, 2024
1 parent 911ef8a commit 2e0bb61
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 8 deletions.
17 changes: 17 additions & 0 deletions thunder/clang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
from types import EllipsisType, NoneType
import copy
import time
import warnings

from thunder.core.baseutils import run_once
from thunder.core.compile_data import using_symbolic_values
from thunder.clang.langctx import register_method
from thunder.core.langctxs import langctx, Languages
Expand Down Expand Up @@ -1160,6 +1162,14 @@ def compute_broadcast_shape(*_shapes):
return tuple(common_shape)


@run_once
def mT_scalar_warning():
    """Emit a ``UserWarning`` that ``Tensor.mT`` on a 0-D tensor is deprecated.

    The message also notes that the operation is the identity for such inputs.
    NOTE(review): ``run_once`` presumably limits this to a single emission per
    process -- confirm against ``thunder.core.baseutils``.
    """
    message = "Tensor.mT is deprecated on 0-D tensors. This function is the identity in these cases."
    warnings.warn(message, category=UserWarning)


@clangop(method_name="mT")
def matrix_transpose(a: TensorProxy) -> TensorProxy:
"""Transposes the last two dimensions of a tensor.
Expand All @@ -1181,6 +1191,13 @@ def matrix_transpose(a: TensorProxy) -> TensorProxy:
[2, 5],
[3, 6]])
"""

if a.ndim == 0:
mT_scalar_warning()
return a
elif a.ndim == 1:
raise RuntimeError(f"tensor.mT is only supported on matrices or batches of matrices. Got 1-D tensor.")

dim0, dim1 = -2, -1
dim0, dim1 = utils.canonicalize_dims(a.ndim, (dim0, dim1))
permutation = list(range(a.ndim))
Expand Down
33 changes: 25 additions & 8 deletions thunder/tests/opinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -3975,17 +3975,30 @@ def matrix_transpose_sample_generator(op, device, dtype, requires_grad, **kwargs

# shape
cases = (
(4, 7, 8),
(4, 7),
(),
(2, 3),
(2, 3, 4),
(2, 3, 4, 2),
)

for shape in cases:
yield SampleInput(make(shape))


def matrix_transpose_error_generator(op, device, dtype=torch.float32, **kwargs):
    """Yield inputs for which the matrix-transpose op is expected to raise.

    Each yielded item is a ``(SampleInput, exception type, error message)``
    triple.

    Args:
        op: Not used by this generator.
        device: Device forwarded to ``make_tensor``.
        dtype: Dtype forwarded to ``make_tensor``; defaults to ``torch.float32``.
        **kwargs: Accepted and ignored.
    """
    make = partial(make_tensor, device=device, dtype=dtype)

    # shape, error type, error message
    cases = (
        # The trailing comma matters: ``(3)`` is the int 3, not a 1-element shape tuple.
        ((3,), RuntimeError, "tensor.mT is only supported on matrices or batches of matrices. Got 1-D tensor."),
    )

    for shape, err_type, err_msg in cases:
        yield SampleInput(make(shape)), err_type, err_msg


# OpInfo entry for clang.matrix_transpose: pairs the op with its valid-input
# generator, its error-input generator, and a reference implementation.
transpose_opinfo = OpInfo(
    clang.matrix_transpose,
    sample_input_generator=matrix_transpose_sample_generator,
    error_input_generator=matrix_transpose_error_generator,
    # Reference result is the input's own ``.mT`` attribute
    # (presumably a torch.Tensor -- confirm against the test harness).
    torch_reference=lambda x: x.mT,
)
# Add the entry to the shape_ops list.
shape_ops.append(transpose_opinfo)
Expand Down Expand Up @@ -6824,9 +6837,11 @@ def cross_entropy_reference_generator(op, device, dtype, requires_grad, **kwargs
C = input_shape[1] if len(input_shape) >= 2 else input_shape[0]
yield SampleInput(
make(shape[0]),
make(shape[1], low=0, high=C, dtype=torch.long, requires_grad=False)
if not probability_target
else make(shape[1], low=0.0, high=1.0, requires_grad=True),
(
make(shape[1], low=0, high=C, dtype=torch.long, requires_grad=False)
if not probability_target
else make(shape[1], low=0.0, high=1.0, requires_grad=True)
),
weight=make(C) if weight_flag else None,
ignore_index=ignore_index,
reduction=reduction_str,
Expand Down Expand Up @@ -6870,9 +6885,11 @@ def cross_entropy_sample_generator(op, device, dtype, requires_grad, **kwargs):
C = input_shape[1] if len(input_shape) >= 2 else input_shape[0]
yield SampleInput(
make(shape[0]),
make(shape[1], low=0, high=C, dtype=torch.long, requires_grad=False)
if not probability_target
else make(shape[1], low=0.0, high=1.0, requires_grad=True),
(
make(shape[1], low=0, high=C, dtype=torch.long, requires_grad=False)
if not probability_target
else make(shape[1], low=0.0, high=1.0, requires_grad=True)
),
weight=make(C) if weight_flag else None,
ignore_index=ignore_index,
reduction=reduction_str,
Expand Down

0 comments on commit 2e0bb61

Please sign in to comment.