Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix .mT, add tests #68

Merged
merged 1 commit into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions thunder/clang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
from types import EllipsisType, NoneType
import copy
import time
import warnings

from thunder.core.baseutils import run_once
from thunder.core.compile_data import using_symbolic_values
from thunder.clang.langctx import register_method
from thunder.core.langctxs import langctx, Languages
Expand Down Expand Up @@ -1160,6 +1162,14 @@ def compute_broadcast_shape(*_shapes):
return tuple(common_shape)


@run_once
def mT_scalar_warning():
    """Warn that ``.mT`` on a 0-D tensor is deprecated and acts as the identity.

    Decorated with ``run_once`` (presumably so the warning is emitted a single
    time per process -- confirm against ``thunder.core.baseutils``).
    """
    message = "Tensor.mT is deprecated on 0-D tensors. This function is the identity in these cases."
    warnings.warn(message, category=UserWarning)


@clangop(method_name="mT")
def matrix_transpose(a: TensorProxy) -> TensorProxy:
"""Transposes the last two dimensions of a tensor.
Expand All @@ -1181,6 +1191,13 @@ def matrix_transpose(a: TensorProxy) -> TensorProxy:
[2, 5],
[3, 6]])
"""

if a.ndim == 0:
mT_scalar_warning()
return a
elif a.ndim == 1:
raise RuntimeError(f"tensor.mT is only supported on matrices or batches of matrices. Got 1-D tensor.")

dim0, dim1 = -2, -1
dim0, dim1 = utils.canonicalize_dims(a.ndim, (dim0, dim1))
permutation = list(range(a.ndim))
Expand Down
33 changes: 25 additions & 8 deletions thunder/tests/opinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -3975,17 +3975,30 @@ def matrix_transpose_sample_generator(op, device, dtype, requires_grad, **kwargs

# shape
cases = (
(4, 7, 8),
(4, 7),
(),
(2, 3),
(2, 3, 4),
(2, 3, 4, 2),
)

for shape in cases:
yield SampleInput(make(shape))


def matrix_transpose_error_generator(op, device, dtype=torch.float32, **kwargs):
    """Yield inputs on which ``matrix_transpose`` is expected to raise.

    Args:
        op: The OpInfo under test (unused here).
        device: Device on which the sample tensors are created.
        dtype: Element type of the sample tensors.
        **kwargs: Accepted and ignored.

    Yields:
        ``(SampleInput, exception type, error message)`` triples.
    """
    make = partial(make_tensor, device=device, dtype=dtype)

    # shape, error type, error message
    # NOTE: the shape is written as the 1-tuple ``(3,)``; a bare ``(3)`` is just
    # the int 3, which only works because make_tensor also accepts ints.
    cases = (
        ((3,), RuntimeError, "tensor.mT is only supported on matrices or batches of matrices. Got 1-D tensor."),
    )

    for shape, err_type, err_msg in cases:
        yield SampleInput(make(shape)), err_type, err_msg


# OpInfo entry for clang.matrix_transpose.
# - sample_input_generator / error_input_generator: the two generators that
#   produce the valid inputs and the inputs expected to raise, respectively.
# - torch_reference: PyTorch's own ``Tensor.mT``, presumably used as the
#   expected-result oracle by the test harness (confirm in OpInfo).
# The entry is then registered in the ``shape_ops`` list.
transpose_opinfo = OpInfo(
clang.matrix_transpose,
sample_input_generator=matrix_transpose_sample_generator,
error_input_generator=matrix_transpose_error_generator,
torch_reference=lambda x: x.mT,
)
shape_ops.append(transpose_opinfo)
Expand Down Expand Up @@ -6824,9 +6837,11 @@ def cross_entropy_reference_generator(op, device, dtype, requires_grad, **kwargs
C = input_shape[1] if len(input_shape) >= 2 else input_shape[0]
yield SampleInput(
make(shape[0]),
make(shape[1], low=0, high=C, dtype=torch.long, requires_grad=False)
if not probability_target
else make(shape[1], low=0.0, high=1.0, requires_grad=True),
(
make(shape[1], low=0, high=C, dtype=torch.long, requires_grad=False)
if not probability_target
else make(shape[1], low=0.0, high=1.0, requires_grad=True)
),
weight=make(C) if weight_flag else None,
ignore_index=ignore_index,
reduction=reduction_str,
Expand Down Expand Up @@ -6870,9 +6885,11 @@ def cross_entropy_sample_generator(op, device, dtype, requires_grad, **kwargs):
C = input_shape[1] if len(input_shape) >= 2 else input_shape[0]
yield SampleInput(
make(shape[0]),
make(shape[1], low=0, high=C, dtype=torch.long, requires_grad=False)
if not probability_target
else make(shape[1], low=0.0, high=1.0, requires_grad=True),
(
make(shape[1], low=0, high=C, dtype=torch.long, requires_grad=False)
if not probability_target
else make(shape[1], low=0.0, high=1.0, requires_grad=True)
),
weight=make(C) if weight_flag else None,
ignore_index=ignore_index,
reduction=reduction_str,
Expand Down
Loading