From d4d9786ff670c64d009f0900e079bb014fa092d8 Mon Sep 17 00:00:00 2001
From: Yan Wang
Date: Wed, 21 Aug 2024 10:08:59 +0200
Subject: [PATCH] Fix auto-registered torch.special operators (#979)

Background:
This PR addresses a bug related to the handling of `torch.special` operators,
discovered during the development of [PR #976](https://github.com/Lightning-AI/lightning-thunder/pull/976).
`torch.special` operators have `__name__` in the format `special_opname`,
requiring extraction of the actual `opname` (a short illustrative snippet
follows the diff below). Similar issues occur with `torch.linalg` and
`torch.fft` operators.

In this PR:
- Add the function `_get_torch_function_name` to infer the Python call name
  from the torch module and function
- Add support for auto-registration of `torch.linalg` and `torch.fft`
  operators, and add tests for them

---
 thunder/tests/test_auto_register_torchops.py | 61 +++++++++++-------
 thunder/torch/__init__.py                    | 34 ++++++++--
 thunder/torch/default_torch_ops.py           | 67 +++++++++++++++++++-
 3 files changed, 132 insertions(+), 30 deletions(-)

diff --git a/thunder/tests/test_auto_register_torchops.py b/thunder/tests/test_auto_register_torchops.py
index bd9ed6606e..2605182633 100644
--- a/thunder/tests/test_auto_register_torchops.py
+++ b/thunder/tests/test_auto_register_torchops.py
@@ -1,26 +1,41 @@
 from functools import partial
 from unittest.mock import patch
+from itertools import islice

 import pytest
 import thunder
 import thunder.torch.default_torch_ops as ops
+from thunder.torch import _get_torch_function_name
 import torch
 from thunder.tests.framework import requiresCUDA, TorchExecutor
 from thunder.tests.make_tensor import make_tensor
 from thunder.tests.opinfos import get_opinfo, OpInfo
 from thunder.tests.test_einops import skipIfNoCUDA

-from torch.testing._internal.common_device_type import skipCPUIfNoLapack
+from torch.testing._internal.common_device_type import skipCPUIfNoLapack, skipCUDAIfNoMagma
 from torch.testing._internal.common_methods_invocations import op_db

 _name2func = {}
-[_name2func.setdefault(v.__name__, v) for v in ops.torch_auto_registered_ops[torch]]
-[_name2func.setdefault(f"nn.functional.{v.__name__}", v) for v in ops.torch_auto_registered_ops[torch.nn.functional]]
-# Use the sample input from torch.xx to test torch.tensor.xx
-[_name2func.setdefault(f"Tensor.{v.__name__}", v) for v in ops.torch_auto_registered_ops[torch.Tensor]]
-# NOTE some `opinfo.name`s don't have torch.xxx but are only used as torch.Tensor.xxx
-torch_tensor_space_names = [v.__name__ for v in ops.torch_auto_registered_ops[torch.Tensor]]
-_opinfos = [opinfo for opinfo in op_db if opinfo.name in _name2func or opinfo.name in torch_tensor_space_names]
+for m, fns in ops.torch_auto_registered_ops.items():
+    m_name = m.__name__[len("torch.") :] if m.__name__.startswith("torch.") else m.__name__
+    for fn in fns:
+        _name2func.setdefault(f"{m_name}.{_get_torch_function_name(m, fn)}", fn)
+
+
+def get_opinfos_for_test():
+    opinfos = []
+    for opinfo in op_db:
+        if (
+            opinfo.name in _name2func
+            or f"Tensor.{opinfo.name}" in _name2func
+            or any(alias.name in _name2func or f"Tensor.{alias.name}" in _name2func for alias in opinfo.aliases)
+        ):
+            opinfos.append(opinfo)
+
+    return opinfos
+
+
+_opinfos = get_opinfos_for_test()


 # Note that successfully catching an exception in this test is also noted as passed
@@ -29,18 +44,18 @@
 @pytest.mark.parametrize("requires_grad", [True, False], ids=("train", "inference"))
 @pytest.mark.parametrize("device,", ["cuda", "cpu"])
 def test_torch_ops_trace(device, requires_grad, op_info):
-    from itertools import islice
-
-    # for op_info in op_infos:
+    if not op_info.supports_autograd and requires_grad:
+        pytest.skip("op_info.supports_autograd is False")
     if device == "cuda" and torch.float32 not in op_info.dtypesIfCUDA:
-        return
+        pytest.skip("float32 is not in op_info.dtypesIfCUDA")
     if device == "cpu" and not torch.float32 in op_info.dtypes:
-        return
-    # No cuda backend support
+        pytest.skip("float32 is not in op_info.dtypes")
     if op_info.name in ("nonzero_static",) and device == "cuda":
-        return
+        pytest.skip("Could not run 'aten::nonzero_static' with arguments from the 'CUDA' backend.")
     if device == "cpu" and not torch._C.has_lapack and skipCPUIfNoLapack in op_info.decorators:
-        return
+        pytest.skip("PyTorch compiled without Lapack")
+    if device == "cuda" and not torch.cuda.has_magma and skipCUDAIfNoMagma in op_info.decorators:
+        pytest.skip("PyTorch compiled without Magma")

     def get_method(op_info):
         # Check if we have registered this method.
@@ -51,7 +66,8 @@ def get_method(op_info):
         return None

     funcs = [_name2func.get(op_info.name, None), get_method(op_info)]
-    for func in funcs:
+    funcs.extend(_name2func.get(alias.name, None) for alias in op_info.aliases)
+    for idx, func in enumerate(funcs):
         if func is None:
             continue
         # It takes too long, test only the first 5 sample inputs
@@ -72,6 +88,8 @@ def get_method(op_info):
                 ).startswith(f"Unsupported type:")
                 break
             else:
+                # Get the alias name when testing for alias
+                cur_op_name = op_info.name if idx < 2 else op_info.aliases[idx - 2].name
                 if requires_grad:
                     trc = thunder.last_backward_traces(jfun)[-1]
                     fwd_trc = thunder.last_traces(jfun)[-1]
@@ -80,7 +98,7 @@ def get_method(op_info):
                     outs = outs if isinstance(outs, tuple) else (outs,)
                     if all(not thunder.core.dtypes.is_inexact_dtype(o.dtype) for o in outs):
                         continue
-                    vjp_op_name = f"{op_info.name.split('.')[-1]}_vjp"
+                    vjp_op_name = f"{cur_op_name.split('.')[-1]}_vjp"
                     if op_info.name == "mm":
                         assert any(bsym.sym.name.endswith(vjp_op_name) for bsym in trc.bound_symbols)
                     else:
@@ -88,7 +106,7 @@ def get_method(op_info):
                 else:
                     fwd_trc = thunder.last_traces(jfun)[-1]
                     assert any(
-                        bsym.sym.name.endswith(op_info.name.split(".")[-1]) and not bsym.subsymbols
+                        bsym.sym.name.endswith(cur_op_name.split(".")[-1]) and not bsym.subsymbols
                         for bsym in fwd_trc.bound_symbols
                     )

@@ -148,10 +166,7 @@ def test_nanogpt_block(self):
         import thunder.tests.nanogpt_model as nanogpt_model

         for op in _skip_ops_nanogpt:
-            if op.name == "gelu":
-                register_default_torch_op(op.torch_reference, torch)
-            else:
-                register_default_torch_op(op.torch_reference, torch.nn.functional)
+            register_default_torch_op(op.torch_reference, torch.nn.functional)
             self._tmp_update_jit_lookup(op.torch_reference)
         tdtype = torch.float32
         device = torch.device("cuda")
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 6781e3bdde..ce3204520e 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -10,7 +10,7 @@
 from functools import partial, reduce, wraps
 from numbers import Number
 from typing import Any, overload
-from types import NoneType
+from types import NoneType, ModuleType
 from collections.abc import Callable

 import opt_einsum
@@ -5422,18 +5422,42 @@ def register_default_torch_ops():
         register_default_torch_op(fn, m)


+def _get_torch_function_name(torch_module: ModuleType, torchfn: Callable):
+    # Handle special cases where torchfn.__name__ differs from the name used to call it in Python,
+    # e.g., `torch.nn.functional.logsigmoid.__name__` is 'log_sigmoid'.
+    special_cases = {torch.nn.functional.logsigmoid: "logsigmoid"}
+    # Operators in the following namespace have an extra prefix in their __name__ attribute compared to their Python call name.
+    # e.g., `torch.special.xlogy.__name__` is 'special_xlogy' instead of just 'xlogy'.
+    name_prefix_map = {torch.special: "special_", torch.linalg: "linalg_", torch.fft: "fft_"}
+    if function_name := special_cases.get(torchfn, None):
+        return function_name
+    if not (function_name := getattr(torchfn, "__name__", None)):
+        raise RuntimeError(
+            f"The function {torchfn} from the module {torch_module} does not have a __name__ attribute. Please ensure that you are passing a valid PyTorch function."
+        )
+    prefix = name_prefix_map.get(torch_module, None)
+    if prefix and function_name.startswith(prefix):
+        function_name = function_name[len(prefix) :]
+    utils.check(
+        getattr(torch_module, function_name, None),
+        lambda: f"Incorrect function name {function_name} inferred for PyTorch function {torchfn} from module {torch_module}.",
+    )
+    return function_name
+
+
 def register_default_torch_op(torchfn: Callable, torch_module):
     fn_meta = meta_adaptor(torchfn)
     _fn = langctx(Languages.TORCH)(fn_meta)
+    torchfn_name = _get_torch_function_name(torch_module, torchfn)
     sym = Symbol(
-        name=torchfn.__name__,
+        name=torchfn_name,
         meta=_fn,
-        id=f"{torch_module.__name__}.{torchfn.__name__}",
+        id=f"{torch_module.__name__}.{torchfn_name}",
     )
     _torch_to_thunder_function_map[torchfn] = sym

     from thunder.executors.torchex import _always_executable, ex
-    op = ex.register_operator(torchfn.__name__, module=torch_module, meta=fn_meta)
+    op = ex.register_operator(torchfn_name, module=torch_module, meta=fn_meta)
     ex.register_implementation(sym, op, checker=_always_executable)

     from thunder.core.transforms import augmented_forward_impls, backward_impls
@@ -5447,7 +5471,7 @@ def register_default_torch_op(torchfn: Callable, torch_module):

     _vjp_impl_wrapper = partial(_vjp_impl, torchfn)

-    bwd_op = ex.register_operator(torchfn.__name__ + "_vjp", meta=backward_adaptor(torchfn), fn=_vjp_impl_wrapper)
+    bwd_op = ex.register_operator(torchfn_name + "_vjp", meta=backward_adaptor(torchfn), fn=_vjp_impl_wrapper)
     ex.register_implementation(bwd_op.id, bwd_op, checker=_always_executable)

     backward_impls[sym.id] = bwd_op
diff --git a/thunder/torch/default_torch_ops.py b/thunder/torch/default_torch_ops.py
index 81f1b10b76..2ad4002d6b 100644
--- a/thunder/torch/default_torch_ops.py
+++ b/thunder/torch/default_torch_ops.py
@@ -399,7 +399,6 @@
         torch.nn.functional.unfold,
     ],
     torch.Tensor: [
-        torch.Tensor.positive,
         torch.Tensor.absolute,
         torch.Tensor.addbmm,
         torch.Tensor.addmm,
@@ -637,7 +636,6 @@
         torch.Tensor.vdot,
         torch.Tensor.vsplit,
         torch.Tensor.xlogy,
-        torch.Tensor.xpu,
     ],
     torch.special: [
         torch.special.airy_ai,
@@ -692,4 +690,69 @@
         torch.special.xlog1py,
         torch.special.xlogy,
     ],
+    torch.linalg: [
+        torch.linalg.cholesky,
+        torch.linalg.cholesky_ex,
+        torch.linalg.cond,
+        torch.linalg.cross,
+        torch.linalg.det,
+        torch.linalg.diagonal,
+        torch.linalg.eig,
+        torch.linalg.eigh,
+        torch.linalg.eigvals,
+        torch.linalg.eigvalsh,
+        torch.linalg.householder_product,
+        torch.linalg.inv,
+        torch.linalg.inv_ex,
+        torch.linalg.ldl_factor,
+        torch.linalg.ldl_factor_ex,
+        torch.linalg.ldl_solve,
+        torch.linalg.lstsq,
+        torch.linalg.lu,
+        torch.linalg.lu_factor,
+        torch.linalg.lu_factor_ex,
+        torch.linalg.lu_solve,
+        torch.linalg.matmul,
+        torch.linalg.matrix_exp,
+        torch.linalg.matrix_norm,
+        torch.linalg.matrix_power,
+        torch.linalg.matrix_rank,
+        torch.linalg.multi_dot,
+        torch.linalg.norm,
+        torch.linalg.pinv,
+        torch.linalg.qr,
+        torch.linalg.slogdet,
+        torch.linalg.solve,
+        torch.linalg.solve_ex,
+        torch.linalg.solve_triangular,
+        torch.linalg.svd,
+        torch.linalg.svdvals,
+        torch.linalg.tensorinv,
+        torch.linalg.tensorsolve,
+        torch.linalg.vander,
+        torch.linalg.vecdot,
+        torch.linalg.vector_norm,
+    ],
+    torch.fft: [
+        torch.fft.fft,
+        torch.fft.fft2,
+        torch.fft.fftn,
+        torch.fft.fftshift,
+        torch.fft.hfft,
+        torch.fft.hfft2,
+        torch.fft.hfftn,
+        torch.fft.ifft,
+        torch.fft.ifft2,
+        torch.fft.ifftn,
+        torch.fft.ifftshift,
+        torch.fft.ihfft,
+        torch.fft.ihfft2,
+        torch.fft.ihfftn,
+        torch.fft.irfft,
+        torch.fft.irfft2,
+        torch.fft.irfftn,
+        torch.fft.rfft,
+        torch.fft.rfft2,
+        torch.fft.rfftn,
+    ],
 }
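
The snippet below is not part of the patch; it is a minimal, illustrative sketch of the naming quirk this PR works around and of the prefix-stripping idea behind `_get_torch_function_name`. It is simplified relative to the real helper (which also special-cases `torch.nn.functional.logsigmoid` and validates the inferred name with `utils.check`); the `infer_call_name` helper and `_PREFIXES` mapping here are hypothetical names used only for illustration.

```python
import torch

# Quirk described in the commit message and the added code comments:
# some submodules expose a prefixed __name__ that differs from the Python call name,
# e.g. torch.special.xlogy.__name__ == "special_xlogy",
#      torch.nn.functional.logsigmoid.__name__ == "log_sigmoid".

# Prefixes handled by the PR's name_prefix_map.
_PREFIXES = {torch.special: "special_", torch.linalg: "linalg_", torch.fft: "fft_"}


def infer_call_name(module, fn):
    # Strip the module-specific prefix only when it is actually present,
    # so functions whose __name__ already matches their call name pass through.
    name = fn.__name__
    prefix = _PREFIXES.get(module, "")
    return name[len(prefix):] if prefix and name.startswith(prefix) else name


assert infer_call_name(torch.special, torch.special.xlogy) == "xlogy"
assert infer_call_name(torch.fft, torch.fft.rfft) == "rfft"
# The inferred name can be used to look the function up again on its module,
# which is what the auto-registration relies on.
assert getattr(torch.special, infer_call_name(torch.special, torch.special.xlogy)) is torch.special.xlogy
```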