Fix auto-registered torch.special operators #979

Merged (7 commits, Aug 21, 2024)
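Background for the changes below: some torch operators expose a __name__ that differs from the name used to call them in Python, which meant the auto-registration machinery could end up with misnamed symbols and executor operators (most visibly for torch.special). A minimal illustration, not part of the diff; the concrete examples come from the comments added to _get_torch_function_name in thunder/torch/__init__.py:

import torch

# __name__ vs. the Python call name (examples cited in the PR's code comments):
print(torch.special.xlogy.__name__)             # 'special_xlogy', not 'xlogy'
print(torch.nn.functional.logsigmoid.__name__)  # 'log_sigmoid', not 'logsigmoid'
# Per the prefix map added below, torch.linalg and torch.fft carry similar
# 'linalg_' and 'fft_' prefixes in __name__.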
61 changes: 38 additions & 23 deletions thunder/tests/test_auto_register_torchops.py
@@ -1,26 +1,41 @@
 from functools import partial
 from unittest.mock import patch
+from itertools import islice

 import pytest
 import thunder
 import thunder.torch.default_torch_ops as ops
+from thunder.torch import _get_torch_function_name
 import torch

 from thunder.tests.framework import requiresCUDA, TorchExecutor
 from thunder.tests.make_tensor import make_tensor
 from thunder.tests.opinfos import get_opinfo, OpInfo
 from thunder.tests.test_einops import skipIfNoCUDA
-from torch.testing._internal.common_device_type import skipCPUIfNoLapack
+from torch.testing._internal.common_device_type import skipCPUIfNoLapack, skipCUDAIfNoMagma
 from torch.testing._internal.common_methods_invocations import op_db

 _name2func = {}
-[_name2func.setdefault(v.__name__, v) for v in ops.torch_auto_registered_ops[torch]]
-[_name2func.setdefault(f"nn.functional.{v.__name__}", v) for v in ops.torch_auto_registered_ops[torch.nn.functional]]
-# Use the sample input from torch.xx to test torch.tensor.xx
-[_name2func.setdefault(f"Tensor.{v.__name__}", v) for v in ops.torch_auto_registered_ops[torch.Tensor]]
-# NOTE some `opinfo.name`s don't have torch.xxx but are only used as torch.Tensor.xxx
-torch_tensor_space_names = [v.__name__ for v in ops.torch_auto_registered_ops[torch.Tensor]]
-_opinfos = [opinfo for opinfo in op_db if opinfo.name in _name2func or opinfo.name in torch_tensor_space_names]
+for m, fns in ops.torch_auto_registered_ops.items():
+    m_name = m.__name__[len("torch.") :] if m.__name__.startswith("torch.") else m.__name__
+    for fn in fns:
+        _name2func.setdefault(f"{m_name}.{_get_torch_function_name(m, fn)}", fn)
+
+
+def get_opinfos_for_test():
+    opinfos = []
+    for opinfo in op_db:
+        if (
+            opinfo.name in _name2func
+            or f"Tensor.{opinfo.name}" in _name2func
+            or any(alias.name in _name2func or f"Tensor.{alias.name}" in _name2func for alias in opinfo.aliases)
+        ):
+            opinfos.append(opinfo)
+
+    return opinfos
+
+
+_opinfos = get_opinfos_for_test()


 # Note that successfully catching an exception in this test is also noted as passed
@@ -29,18 +44,18 @@
@pytest.mark.parametrize("requires_grad", [True, False], ids=("train", "inference"))
@pytest.mark.parametrize("device,", ["cuda", "cpu"])
def test_torch_ops_trace(device, requires_grad, op_info):
from itertools import islice

# for op_info in op_infos:
if not op_info.supports_autograd and requires_grad:
pytest.skip("op_info.supports_autograd is False")
if device == "cuda" and torch.float32 not in op_info.dtypesIfCUDA:
return
pytest.skip("float32 is not in op_info.dtypesIfCUDA")
if device == "cpu" and not torch.float32 in op_info.dtypes:
return
# No cuda backend support
pytest.skip("float32 is not in op_info.dtypes")
if op_info.name in ("nonzero_static",) and device == "cuda":
return
pytest.skip("Could not run 'aten::nonzero_static' with arguments from the 'CUDA' backend.")
if device == "cpu" and not torch._C.has_lapack and skipCPUIfNoLapack in op_info.decorators:
return
pytest.skip("PyTorch compiled without Lapack")
if device == "cuda" and not torch.cuda.has_magma and skipCUDAIfNoMagma in op_info.decorators:
pytest.skip("PyTorch compiled without Magma")

def get_method(op_info):
# Check if we have registered this method.
@@ -51,7 +66,8 @@ def get_method(op_info):
         return None

     funcs = [_name2func.get(op_info.name, None), get_method(op_info)]
-    for func in funcs:
+    funcs.extend(_name2func.get(alias.name, None) for alias in op_info.aliases)
+    for idx, func in enumerate(funcs):
         if func is None:
             continue
         # It takes too long, test only the first 5 sample inputs
@@ -72,6 +88,8 @@ def get_method(op_info):
).startswith(f"Unsupported type:")
break
else:
# Get the alias name when testing for alias
cur_op_name = op_info.name if idx < 2 else op_info.aliases[idx - 2].name
if requires_grad:
trc = thunder.last_backward_traces(jfun)[-1]
fwd_trc = thunder.last_traces(jfun)[-1]
@@ -80,15 +98,15 @@ def get_method(op_info):
                     outs = outs if isinstance(outs, tuple) else (outs,)
                     if all(not thunder.core.dtypes.is_inexact_dtype(o.dtype) for o in outs):
                         continue
-                    vjp_op_name = f"{op_info.name.split('.')[-1]}_vjp"
+                    vjp_op_name = f"{cur_op_name.split('.')[-1]}_vjp"
                     if op_info.name == "mm":
                         assert any(bsym.sym.name.endswith(vjp_op_name) for bsym in trc.bound_symbols)
                     else:
                         assert any(bsym.sym.name == vjp_op_name for bsym in trc.bound_symbols)
                 else:
                     fwd_trc = thunder.last_traces(jfun)[-1]
                     assert any(
-                        bsym.sym.name.endswith(op_info.name.split(".")[-1]) and not bsym.subsymbols
+                        bsym.sym.name.endswith(cur_op_name.split(".")[-1]) and not bsym.subsymbols
                         for bsym in fwd_trc.bound_symbols
                     )

@@ -148,10 +166,7 @@ def test_nanogpt_block(self):
         import thunder.tests.nanogpt_model as nanogpt_model

         for op in _skip_ops_nanogpt:
-            if op.name == "gelu":
-                register_default_torch_op(op.torch_reference, torch)
-            else:
-                register_default_torch_op(op.torch_reference, torch.nn.functional)
+            register_default_torch_op(op.torch_reference, torch.nn.functional)
             self._tmp_update_jit_lookup(op.torch_reference)
         tdtype = torch.float32
         device = torch.device("cuda")
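As a rough sketch of what the new test setup above builds (a hypothetical dictionary literal; the entries shown are ops visible elsewhere in this diff), _name2func now keys every auto-registered op by its qualified call name, which is what lets OpInfo names such as "special.xlogy" or "linalg.svd" be matched directly:

# Hypothetical sample of _name2func after the registration loop at the top of the file.
# Keys are f"{module name without the leading 'torch.'}.{normalized function name}".
# {
#     "nn.functional.unfold": torch.nn.functional.unfold,
#     "Tensor.absolute": torch.Tensor.absolute,
#     "special.xlogy": torch.special.xlogy,
#     "linalg.svd": torch.linalg.svd,
#     "fft.fft": torch.fft.fft,
# }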
34 changes: 29 additions & 5 deletions thunder/torch/__init__.py
@@ -10,7 +10,7 @@
 from functools import partial, reduce, wraps
 from numbers import Number
 from typing import Any, overload
-from types import NoneType
+from types import NoneType, ModuleType
 from collections.abc import Callable

 import opt_einsum
@@ -5422,18 +5422,42 @@ def register_default_torch_ops():
             register_default_torch_op(fn, m)


+def _get_torch_function_name(torch_module: ModuleType, torchfn: Callable):
+    # Handle special cases where torchfn.__name__ differs from the name used to call it in Python,
+    # e.g., `torch.nn.functional.logsigmoid.__name__` is 'log_sigmoid'.
+    special_cases = {torch.nn.functional.logsigmoid: "logsigmoid"}
+    # Operators in the following namespace have an extra prefix in their __name__ attribute compared to their Python call name.
+    # e.g., `torch.special.xlogy.__name__` is 'special_xlogy' instead of just 'xlogy'.
+    name_prefix_map = {torch.special: "special_", torch.linalg: "linalg_", torch.fft: "fft_"}
+    if function_name := special_cases.get(torchfn, None):
+        return function_name
+    if not (function_name := getattr(torchfn, "__name__", None)):
+        raise RuntimeError(
+            f"The function {torchfn} from the module {torch_module} does not have a __name__ attribute. Please ensure that you are passing a valid PyTorch function."
+        )
+    prefix = name_prefix_map.get(torch_module, None)
+    if prefix and function_name.startswith(prefix):
+        function_name = function_name[len(prefix) :]
+    utils.check(
+        getattr(torch_module, function_name, None),
+        lambda: f"Incorrect function name {function_name} inferred for PyTorch function {torchfn} from module {torch_module}.",
+    )
+    return function_name
+
+
 def register_default_torch_op(torchfn: Callable, torch_module):
     fn_meta = meta_adaptor(torchfn)
     _fn = langctx(Languages.TORCH)(fn_meta)
+    torchfn_name = _get_torch_function_name(torch_module, torchfn)
     sym = Symbol(
-        name=torchfn.__name__,
+        name=torchfn_name,
         meta=_fn,
-        id=f"{torch_module.__name__}.{torchfn.__name__}",
+        id=f"{torch_module.__name__}.{torchfn_name}",
     )
     _torch_to_thunder_function_map[torchfn] = sym
     from thunder.executors.torchex import _always_executable, ex

-    op = ex.register_operator(torchfn.__name__, module=torch_module, meta=fn_meta)
+    op = ex.register_operator(torchfn_name, module=torch_module, meta=fn_meta)
     ex.register_implementation(sym, op, checker=_always_executable)

     from thunder.core.transforms import augmented_forward_impls, backward_impls
@@ -5447,7 +5471,7 @@ def register_default_torch_op(torchfn: Callable, torch_module):

     _vjp_impl_wrapper = partial(_vjp_impl, torchfn)

-    bwd_op = ex.register_operator(torchfn.__name__ + "_vjp", meta=backward_adaptor(torchfn), fn=_vjp_impl_wrapper)
+    bwd_op = ex.register_operator(torchfn_name + "_vjp", meta=backward_adaptor(torchfn), fn=_vjp_impl_wrapper)
     ex.register_implementation(bwd_op.id, bwd_op, checker=_always_executable)
     backward_impls[sym.id] = bwd_op

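To make the helper concrete, a few illustrative calls (assumed expectations, not taken from the PR); the results follow from the special-case table and prefix map in _get_torch_function_name above:

_get_torch_function_name(torch.special, torch.special.xlogy)  # "xlogy" ('special_' prefix stripped)
_get_torch_function_name(torch.linalg, torch.linalg.svd)      # "svd"
_get_torch_function_name(torch.fft, torch.fft.fftn)           # "fftn"
_get_torch_function_name(torch.nn.functional, torch.nn.functional.logsigmoid)  # "logsigmoid" (special case)

With the normalized name, register_default_torch_op now registers, for example, torch.special.xlogy as a Symbol named "xlogy" with id "torch.special.xlogy" (and an "xlogy_vjp" backward operator), where it previously used the raw __name__ "special_xlogy".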
67 changes: 65 additions & 2 deletions thunder/torch/default_torch_ops.py
@@ -399,7 +399,6 @@
         torch.nn.functional.unfold,
     ],
     torch.Tensor: [
-        torch.Tensor.positive,
         torch.Tensor.absolute,
         torch.Tensor.addbmm,
         torch.Tensor.addmm,
@@ -637,7 +636,6 @@
         torch.Tensor.vdot,
         torch.Tensor.vsplit,
         torch.Tensor.xlogy,
-        torch.Tensor.xpu,
     ],
     torch.special: [
         torch.special.airy_ai,
@@ -692,4 +690,69 @@
         torch.special.xlog1py,
         torch.special.xlogy,
     ],
+    torch.linalg: [
+        torch.linalg.cholesky,
+        torch.linalg.cholesky_ex,
+        torch.linalg.cond,
+        torch.linalg.cross,
+        torch.linalg.det,
+        torch.linalg.diagonal,
+        torch.linalg.eig,
+        torch.linalg.eigh,
+        torch.linalg.eigvals,
+        torch.linalg.eigvalsh,
+        torch.linalg.householder_product,
+        torch.linalg.inv,
+        torch.linalg.inv_ex,
+        torch.linalg.ldl_factor,
+        torch.linalg.ldl_factor_ex,
+        torch.linalg.ldl_solve,
+        torch.linalg.lstsq,
+        torch.linalg.lu,
+        torch.linalg.lu_factor,
+        torch.linalg.lu_factor_ex,
+        torch.linalg.lu_solve,
+        torch.linalg.matmul,
+        torch.linalg.matrix_exp,
+        torch.linalg.matrix_norm,
+        torch.linalg.matrix_power,
+        torch.linalg.matrix_rank,
+        torch.linalg.multi_dot,
+        torch.linalg.norm,
+        torch.linalg.pinv,
+        torch.linalg.qr,
+        torch.linalg.slogdet,
+        torch.linalg.solve,
+        torch.linalg.solve_ex,
+        torch.linalg.solve_triangular,
+        torch.linalg.svd,
+        torch.linalg.svdvals,
+        torch.linalg.tensorinv,
+        torch.linalg.tensorsolve,
+        torch.linalg.vander,
+        torch.linalg.vecdot,
+        torch.linalg.vector_norm,
+    ],
+    torch.fft: [
+        torch.fft.fft,
+        torch.fft.fft2,
+        torch.fft.fftn,
+        torch.fft.fftshift,
+        torch.fft.hfft,
+        torch.fft.hfft2,
+        torch.fft.hfftn,
+        torch.fft.ifft,
+        torch.fft.ifft2,
+        torch.fft.ifftn,
+        torch.fft.ifftshift,
+        torch.fft.ihfft,
+        torch.fft.ihfft2,
+        torch.fft.ihfftn,
+        torch.fft.irfft,
+        torch.fft.irfft2,
+        torch.fft.irfftn,
+        torch.fft.rfft,
+        torch.fft.rfft2,
+        torch.fft.rfftn,
+    ],
 }
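A minimal usage sketch of what the expanded tables enable (an assumed example, not from the PR; thunder.jit and thunder.last_traces are the same entry points the updated tests use):

import torch
import thunder

def fn(x):
    # Both calls should fall back to the auto-registered torch ops listed above.
    return torch.fft.fft(x), torch.linalg.matrix_norm(x)

jfn = thunder.jit(fn)
outs = jfn(torch.randn(4, 4))

# Per test_torch_ops_trace, the forward trace is expected to contain bound symbols
# whose names end with "fft" and "matrix_norm" and that have no subsymbols.
print(thunder.last_traces(jfn)[-1])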