Fix auto-registered torch.special operators (#979)
Background:
This PR fixes a bug in the handling of `torch.special` operators that was discovered during the development of PR #976.

`torch.special` operators have a `__name__` in the format `special_opname`, so the actual `opname` must be extracted before registration. The same issue affects `torch.linalg` and `torch.fft` operators.
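For illustration, a quick check of the mismatch (the `torch.special.xlogy` case is the one confirmed in this PR; the `torch.linalg`/`torch.fft` lines assume those namespaces follow the same `namespace_opname` pattern):

```python
import torch

# The Python call name is torch.special.xlogy, but __name__ carries a
# namespace prefix, so registration keyed on __name__ resolves the wrong name.
print(torch.special.xlogy.__name__)  # 'special_xlogy', not 'xlogy'
print(torch.fft.fft.__name__)        # expected 'fft_fft' under the same pattern
print(torch.linalg.det.__name__)     # expected 'linalg_det' under the same pattern
```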

In this PR:

- Add a helper, `_get_torch_function_name`, that infers the Python call name from the torch module and the function
- Add support for auto-registering `torch.linalg` and `torch.fft` operators, along with tests (a usage sketch follows this list)
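A minimal sketch of the helper's intended contract, based on the implementation in `thunder/torch/__init__.py` below (it assumes the helper is importable as in the updated tests):

```python
import torch
from thunder.torch import _get_torch_function_name

# Prefixed __name__ attributes are stripped back to the Python call name.
assert _get_torch_function_name(torch.special, torch.special.xlogy) == "xlogy"
# Special case: __name__ is 'log_sigmoid', but the call name is 'logsigmoid'.
assert _get_torch_function_name(torch.nn.functional, torch.nn.functional.logsigmoid) == "logsigmoid"
# Functions whose __name__ already matches their call name pass through unchanged.
assert _get_torch_function_name(torch, torch.add) == "add"
```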
kiya00 authored Aug 21, 2024
1 parent 3548ba8 commit d4d9786
Showing 3 changed files with 132 additions and 30 deletions.
61 changes: 38 additions & 23 deletions thunder/tests/test_auto_register_torchops.py
@@ -1,26 +1,41 @@
from functools import partial
from unittest.mock import patch
from itertools import islice

import pytest
import thunder
import thunder.torch.default_torch_ops as ops
from thunder.torch import _get_torch_function_name
import torch

from thunder.tests.framework import requiresCUDA, TorchExecutor
from thunder.tests.make_tensor import make_tensor
from thunder.tests.opinfos import get_opinfo, OpInfo
from thunder.tests.test_einops import skipIfNoCUDA
from torch.testing._internal.common_device_type import skipCPUIfNoLapack
from torch.testing._internal.common_device_type import skipCPUIfNoLapack, skipCUDAIfNoMagma
from torch.testing._internal.common_methods_invocations import op_db

_name2func = {}
[_name2func.setdefault(v.__name__, v) for v in ops.torch_auto_registered_ops[torch]]
[_name2func.setdefault(f"nn.functional.{v.__name__}", v) for v in ops.torch_auto_registered_ops[torch.nn.functional]]
# Use the sample input from torch.xx to test torch.tensor.xx
[_name2func.setdefault(f"Tensor.{v.__name__}", v) for v in ops.torch_auto_registered_ops[torch.Tensor]]
# NOTE some `opinfo.name`s don't have torch.xxx but are only used as torch.Tensor.xxx
torch_tensor_space_names = [v.__name__ for v in ops.torch_auto_registered_ops[torch.Tensor]]
_opinfos = [opinfo for opinfo in op_db if opinfo.name in _name2func or opinfo.name in torch_tensor_space_names]
for m, fns in ops.torch_auto_registered_ops.items():
m_name = m.__name__[len("torch.") :] if m.__name__.startswith("torch.") else m.__name__
for fn in fns:
_name2func.setdefault(f"{m_name}.{_get_torch_function_name(m, fn)}", fn)


def get_opinfos_for_test():
opinfos = []
for opinfo in op_db:
if (
opinfo.name in _name2func
or f"Tensor.{opinfo.name}" in _name2func
or any(alias.name in _name2func or f"Tensor.{alias.name}" in _name2func for alias in opinfo.aliases)
):
opinfos.append(opinfo)

return opinfos


_opinfos = get_opinfos_for_test()


# Note that successfully catching an exception in this test is also noted as passed
@@ -29,18 +44,18 @@
@pytest.mark.parametrize("requires_grad", [True, False], ids=("train", "inference"))
@pytest.mark.parametrize("device,", ["cuda", "cpu"])
def test_torch_ops_trace(device, requires_grad, op_info):
from itertools import islice

# for op_info in op_infos:
if not op_info.supports_autograd and requires_grad:
pytest.skip("op_info.supports_autograd is False")
if device == "cuda" and torch.float32 not in op_info.dtypesIfCUDA:
return
pytest.skip("float32 is not in op_info.dtypesIfCUDA")
if device == "cpu" and not torch.float32 in op_info.dtypes:
return
# No cuda backend support
pytest.skip("float32 is not in op_info.dtypes")
if op_info.name in ("nonzero_static",) and device == "cuda":
return
pytest.skip("Could not run 'aten::nonzero_static' with arguments from the 'CUDA' backend.")
if device == "cpu" and not torch._C.has_lapack and skipCPUIfNoLapack in op_info.decorators:
return
pytest.skip("PyTorch compiled without Lapack")
if device == "cuda" and not torch.cuda.has_magma and skipCUDAIfNoMagma in op_info.decorators:
pytest.skip("PyTorch compiled without Magma")

def get_method(op_info):
# Check if we have registered this method.
@@ -51,7 +66,8 @@ def get_method(op_info):
return None

funcs = [_name2func.get(op_info.name, None), get_method(op_info)]
for func in funcs:
funcs.extend(_name2func.get(alias.name, None) for alias in op_info.aliases)
for idx, func in enumerate(funcs):
if func is None:
continue
# It takes too long, test only the first 5 sample inputs
@@ -72,6 +88,8 @@ def get_method(op_info):
).startswith(f"Unsupported type:")
break
else:
# Get the alias name when testing for alias
cur_op_name = op_info.name if idx < 2 else op_info.aliases[idx - 2].name
if requires_grad:
trc = thunder.last_backward_traces(jfun)[-1]
fwd_trc = thunder.last_traces(jfun)[-1]
@@ -80,15 +98,15 @@
outs = outs if isinstance(outs, tuple) else (outs,)
if all(not thunder.core.dtypes.is_inexact_dtype(o.dtype) for o in outs):
continue
vjp_op_name = f"{op_info.name.split('.')[-1]}_vjp"
vjp_op_name = f"{cur_op_name.split('.')[-1]}_vjp"
if op_info.name == "mm":
assert any(bsym.sym.name.endswith(vjp_op_name) for bsym in trc.bound_symbols)
else:
assert any(bsym.sym.name == vjp_op_name for bsym in trc.bound_symbols)
else:
fwd_trc = thunder.last_traces(jfun)[-1]
assert any(
bsym.sym.name.endswith(op_info.name.split(".")[-1]) and not bsym.subsymbols
bsym.sym.name.endswith(cur_op_name.split(".")[-1]) and not bsym.subsymbols
for bsym in fwd_trc.bound_symbols
)

@@ -148,10 +166,7 @@ def test_nanogpt_block(self):
import thunder.tests.nanogpt_model as nanogpt_model

for op in _skip_ops_nanogpt:
if op.name == "gelu":
register_default_torch_op(op.torch_reference, torch)
else:
register_default_torch_op(op.torch_reference, torch.nn.functional)
register_default_torch_op(op.torch_reference, torch.nn.functional)
self._tmp_update_jit_lookup(op.torch_reference)
tdtype = torch.float32
device = torch.device("cuda")
34 changes: 29 additions & 5 deletions thunder/torch/__init__.py
@@ -10,7 +10,7 @@
from functools import partial, reduce, wraps
from numbers import Number
from typing import Any, overload
from types import NoneType
from types import NoneType, ModuleType
from collections.abc import Callable

import opt_einsum
@@ -5422,18 +5422,42 @@ def register_default_torch_ops():
register_default_torch_op(fn, m)


def _get_torch_function_name(torch_module: ModuleType, torchfn: Callable):
# Handle special cases where torchfn.__name__ differs from the name used to call it in Python,
# e.g., `torch.nn.functional.logsigmoid.__name__` is 'log_sigmoid'.
special_cases = {torch.nn.functional.logsigmoid: "logsigmoid"}
# Operators in the following namespace have an extra prefix in their __name__ attribute compared to their Python call name.
# e.g., `torch.special.xlogy.__name__` is 'special_xlogy' instead of just 'xlogy'.
name_prefix_map = {torch.special: "special_", torch.linalg: "linalg_", torch.fft: "fft_"}
if function_name := special_cases.get(torchfn, None):
return function_name
if not (function_name := getattr(torchfn, "__name__", None)):
raise RuntimeError(
f"The function {torchfn} from the module {torch_module} does not have a __name__ attribute. Please ensure that you are passing a valid PyTorch function."
)
prefix = name_prefix_map.get(torch_module, None)
if prefix and function_name.startswith(prefix):
function_name = function_name[len(prefix) :]
utils.check(
getattr(torch_module, function_name, None),
lambda: f"Incorrect function name {function_name} inferred for PyTorch function {torchfn} from module {torch_module}.",
)
return function_name


def register_default_torch_op(torchfn: Callable, torch_module):
fn_meta = meta_adaptor(torchfn)
_fn = langctx(Languages.TORCH)(fn_meta)
torchfn_name = _get_torch_function_name(torch_module, torchfn)
sym = Symbol(
name=torchfn.__name__,
name=torchfn_name,
meta=_fn,
id=f"{torch_module.__name__}.{torchfn.__name__}",
id=f"{torch_module.__name__}.{torchfn_name}",
)
_torch_to_thunder_function_map[torchfn] = sym
from thunder.executors.torchex import _always_executable, ex

op = ex.register_operator(torchfn.__name__, module=torch_module, meta=fn_meta)
op = ex.register_operator(torchfn_name, module=torch_module, meta=fn_meta)
ex.register_implementation(sym, op, checker=_always_executable)

from thunder.core.transforms import augmented_forward_impls, backward_impls
@@ -5447,7 +5471,7 @@ def register_default_torch_op(torchfn: Callable, torch_module):

_vjp_impl_wrapper = partial(_vjp_impl, torchfn)

bwd_op = ex.register_operator(torchfn.__name__ + "_vjp", meta=backward_adaptor(torchfn), fn=_vjp_impl_wrapper)
bwd_op = ex.register_operator(torchfn_name + "_vjp", meta=backward_adaptor(torchfn), fn=_vjp_impl_wrapper)
ex.register_implementation(bwd_op.id, bwd_op, checker=_always_executable)
backward_impls[sym.id] = bwd_op

67 changes: 65 additions & 2 deletions thunder/torch/default_torch_ops.py
@@ -399,7 +399,6 @@
torch.nn.functional.unfold,
],
torch.Tensor: [
torch.Tensor.positive,
torch.Tensor.absolute,
torch.Tensor.addbmm,
torch.Tensor.addmm,
@@ -637,7 +636,6 @@
torch.Tensor.vdot,
torch.Tensor.vsplit,
torch.Tensor.xlogy,
torch.Tensor.xpu,
],
torch.special: [
torch.special.airy_ai,
@@ -692,4 +690,69 @@
torch.special.xlog1py,
torch.special.xlogy,
],
torch.linalg: [
torch.linalg.cholesky,
torch.linalg.cholesky_ex,
torch.linalg.cond,
torch.linalg.cross,
torch.linalg.det,
torch.linalg.diagonal,
torch.linalg.eig,
torch.linalg.eigh,
torch.linalg.eigvals,
torch.linalg.eigvalsh,
torch.linalg.householder_product,
torch.linalg.inv,
torch.linalg.inv_ex,
torch.linalg.ldl_factor,
torch.linalg.ldl_factor_ex,
torch.linalg.ldl_solve,
torch.linalg.lstsq,
torch.linalg.lu,
torch.linalg.lu_factor,
torch.linalg.lu_factor_ex,
torch.linalg.lu_solve,
torch.linalg.matmul,
torch.linalg.matrix_exp,
torch.linalg.matrix_norm,
torch.linalg.matrix_power,
torch.linalg.matrix_rank,
torch.linalg.multi_dot,
torch.linalg.norm,
torch.linalg.pinv,
torch.linalg.qr,
torch.linalg.slogdet,
torch.linalg.solve,
torch.linalg.solve_ex,
torch.linalg.solve_triangular,
torch.linalg.svd,
torch.linalg.svdvals,
torch.linalg.tensorinv,
torch.linalg.tensorsolve,
torch.linalg.vander,
torch.linalg.vecdot,
torch.linalg.vector_norm,
],
torch.fft: [
torch.fft.fft,
torch.fft.fft2,
torch.fft.fftn,
torch.fft.fftshift,
torch.fft.hfft,
torch.fft.hfft2,
torch.fft.hfftn,
torch.fft.ifft,
torch.fft.ifft2,
torch.fft.ifftn,
torch.fft.ifftshift,
torch.fft.ihfft,
torch.fft.ihfft2,
torch.fft.ihfftn,
torch.fft.irfft,
torch.fft.irfft2,
torch.fft.irfftn,
torch.fft.rfft,
torch.fft.rfft2,
torch.fft.rfftn,
],
}
