From 01a7167e10c88002f3e2bd92a912ab331563d8f0 Mon Sep 17 00:00:00 2001 From: Yan Wang Date: Mon, 5 Aug 2024 16:31:17 +0200 Subject: [PATCH] fix tests --- thunder/tests/test_auto_register_torchops.py | 163 ++++++++++--------- thunder/torch/__init__.py | 2 +- 2 files changed, 83 insertions(+), 82 deletions(-) diff --git a/thunder/tests/test_auto_register_torchops.py b/thunder/tests/test_auto_register_torchops.py index 389b1e2c85..7264523de7 100644 --- a/thunder/tests/test_auto_register_torchops.py +++ b/thunder/tests/test_auto_register_torchops.py @@ -3,117 +3,118 @@ import pytest import thunder +import thunder.torch.default_torch_ops as ops import torch from thunder.tests.framework import requiresCUDA, TorchExecutor from thunder.tests.make_tensor import make_tensor from thunder.tests.opinfos import get_opinfo, OpInfo from thunder.tests.test_einops import skipIfNoCUDA +from torch.testing._internal.common_device_type import skipCPUIfNoLapack +from torch.testing._internal.common_methods_invocations import op_db + +_name2func = {} +[_name2func.setdefault(v.__name__, v) for v in ops.torch_auto_registered_ops[torch]] +[_name2func.setdefault(f"nn.functional.{v.__name__}", v) for v in ops.torch_auto_registered_ops[torch.nn.functional]] +# Use the sample input from torch.xx to test torch.tensor.xx +[_name2func.setdefault(f"Tensor.{v.__name__}", v) for v in ops.torch_auto_registered_ops[torch.Tensor]] +_opinfos = [opinfo for opinfo in op_db if opinfo.name in _name2func] @skipIfNoCUDA +@pytest.mark.parametrize("op_info,", _opinfos, ids=list(map(lambda opinfo: opinfo.name, _opinfos))) @pytest.mark.parametrize("requires_grad", [True, False], ids=("train", "inference")) @pytest.mark.parametrize("device,", ["cuda", "cpu"]) -def test_torch_ops_trace(device, requires_grad): +def test_torch_ops_trace(device, requires_grad, op_info): from itertools import islice - import thunder.torch.default_torch_ops as ops - from torch.testing._internal.common_methods_invocations import op_db - from torch.testing._internal.common_device_type import skipCPUIfNoLapack - - name2func = {} - [name2func.setdefault(v.__name__, v) for v in ops.torch_auto_registered_ops[torch]] - [name2func.setdefault(f"nn.functional.{v.__name__}", v) for v in ops.torch_auto_registered_ops[torch.nn.functional]] - # Use the sample input from torch.xx to test torch.tensor.xx - [name2func.setdefault(f"Tensor.{v.__name__}", v) for v in ops.torch_auto_registered_ops[torch.Tensor]] - op_infos = [opinfo for opinfo in op_db if opinfo.name in name2func] total = 0 - cnt = 0 + exception_cnt = 0 suc = 0 - for op_info in op_infos: - if device == "cuda" and torch.float32 not in op_info.dtypesIfCUDA: - continue - if device == "cpu" and not torch.float32 in op_info.dtypes: + # for op_info in op_infos: + if device == "cuda" and torch.float32 not in op_info.dtypesIfCUDA: + return + if device == "cpu" and not torch.float32 in op_info.dtypes: + return + # No cuda backend support + if op_info.name in ("nonzero_static",) and device == "cuda": + return + if device == "cpu" and not torch._C.has_lapack and skipCPUIfNoLapack in op_info.decorators: + return + funcs = [_name2func[op_info.name], _name2func.get(f"Tensor.{op_info.name}", None)] + for func in funcs: + if func is None: continue - # No cuda backend support - if op_info.name in ("nonzero_static",) and device == "cuda": - continue - if device == "cpu" and not torch._C.has_lapack and skipCPUIfNoLapack in op_info.decorators: - continue - funcs = [name2func[op_info.name], name2func.get(f"Tensor.{op_info.name}", None)] - for func in funcs: - if func is None: - continue - total += 1 - # It takes too long, test only the first 5 sample inputs - gen = islice( - op_info.sample_inputs_func( - op_info, device=torch.device(device), dtype=torch.float32, requires_grad=requires_grad - ), - 5, - ) - for sample in gen: - try: - jfun = thunder.jit(func) - out = jfun(sample.input, *sample.args, **sample.kwargs) - except Exception as e: - cnt += 1 - assert isinstance(e, NotImplementedError) - assert str(e).startswith(f"Exception encountered when doing automatic registration") or str( - e - ).startswith(f"Unsupported type:") - break - else: - if requires_grad: - trc = thunder.last_backward_traces(jfun)[-1] - fwd_trc = thunder.last_traces(jfun)[-1] - # skip if it is not differentiable - outs = fwd_trc.output[0]["output"] - outs = outs if isinstance(outs, tuple) else (outs,) - if all(not thunder.core.dtypes.is_inexact_dtype(o.dtype) for o in outs): - continue - vjp_op_name = f"{op_info.name.split('.')[-1]}_vjp" - if op_info.name == "mm": - assert any(bsym.sym.name.endswith(vjp_op_name) for bsym in trc.bound_symbols) - else: - assert any(bsym.sym.name == vjp_op_name for bsym in trc.bound_symbols) - else: - fwd_trc = thunder.last_traces(jfun)[-1] - assert any( - bsym.sym.name.endswith(op_info.name.split(".")[-1]) and not bsym.subsymbols - for bsym in fwd_trc.bound_symbols - ) + total += 1 + # It takes too long, test only the first 5 sample inputs + gen = islice( + op_info.sample_inputs_func( + op_info, device=torch.device(device), dtype=torch.float32, requires_grad=requires_grad + ), + 5, + ) + for sample in gen: + try: + jfun = thunder.jit(func) + out = jfun(sample.input, *sample.args, **sample.kwargs) + except Exception as e: + exception_cnt += 1 + assert isinstance(e, NotImplementedError) + assert str(e).startswith(f"Exception encountered when doing automatic registration") or str( + e + ).startswith(f"Unsupported type:") + break else: - suc += 1 + if requires_grad: + trc = thunder.last_backward_traces(jfun)[-1] + fwd_trc = thunder.last_traces(jfun)[-1] + # skip if it is not differentiable + outs = fwd_trc.output[0]["output"] + outs = outs if isinstance(outs, tuple) else (outs,) + if all(not thunder.core.dtypes.is_inexact_dtype(o.dtype) for o in outs): + continue + vjp_op_name = f"{op_info.name.split('.')[-1]}_vjp" + if op_info.name == "mm": + assert any(bsym.sym.name.endswith(vjp_op_name) for bsym in trc.bound_symbols) + else: + assert any(bsym.sym.name == vjp_op_name for bsym in trc.bound_symbols) + else: + fwd_trc = thunder.last_traces(jfun)[-1] + assert any( + bsym.sym.name.endswith(op_info.name.split(".")[-1]) and not bsym.subsymbols + for bsym in fwd_trc.bound_symbols + ) + else: + suc += 1 # Replace manual registration of some operations with automatic registration for network test cases -skip_ops_nanogpt = [ +_skip_ops_nanogpt = [ get_opinfo("layer_norm"), get_opinfo("linear"), get_opinfo("gelu"), get_opinfo("scaled_dot_product_attention"), ] -skip_ops_alexnet = [ +_skip_ops_alexnet = [ get_opinfo("conv2d"), get_opinfo("linear"), get_opinfo("adaptive_avg_pool2d"), get_opinfo("max_pool2d"), ] -disable_opinfos = skip_ops_nanogpt + skip_ops_alexnet -tmp1 = dict(thunder.core.jit_ext._general_jit_lookaside_map) -list(tmp1.pop(k.torch_reference, None) for k in disable_opinfos) -tmp2 = dict(thunder.torch._torch_to_thunder_function_map) -list(tmp2.pop(k.torch_reference, None) for k in disable_opinfos) -tmp3 = dict(thunder.core.jit_ext._minimal_lookaside_map) -list(tmp3.pop(k.torch_reference, None) for k in disable_opinfos) +_disable_opinfos = _skip_ops_nanogpt + _skip_ops_alexnet +_tmp_general_jit_lookaside_map = dict(thunder.core.jit_ext._general_jit_lookaside_map) +list(_tmp_general_jit_lookaside_map.pop(k.torch_reference, None) for k in _disable_opinfos) +_tmp_torch_to_thunder_function_map = dict(thunder.torch._torch_to_thunder_function_map) +list(_tmp_torch_to_thunder_function_map.pop(k.torch_reference, None) for k in _disable_opinfos) +_tmp_minimal_lookaside_map = dict(thunder.core.jit_ext._minimal_lookaside_map) +list(_tmp_minimal_lookaside_map.pop(k.torch_reference, None) for k in _disable_opinfos) from thunder.torch import register_default_torch_op # mock all the global variables that are modified during registration -@patch.dict(thunder.core.jit_ext._general_jit_lookaside_map, tmp1, clear=True) -@patch.dict(thunder.torch._torch_to_thunder_function_map, tmp2, clear=True) -@patch.dict(thunder.core.jit_ext._minimal_lookaside_map, tmp3, clear=True) +@patch.dict(thunder.core.jit_ext._general_jit_lookaside_map, _tmp_general_jit_lookaside_map, clear=True) +@patch.dict(thunder.torch._torch_to_thunder_function_map, _tmp_torch_to_thunder_function_map, clear=True) +@patch.dict(thunder.core.jit_ext._minimal_lookaside_map, _tmp_minimal_lookaside_map, clear=True) @patch.dict(thunder.executors.torchex.ex._implmap, {}) @patch.dict(thunder.executors.torchex.ex._opmap, {}) @patch.dict(thunder.core.transforms.augmented_forward_impls, {}) @@ -141,7 +142,7 @@ def _tmp_update_jit_lookup(self, torchfn): def test_nanogpt_block(self): import thunder.tests.nanogpt_model as nanogpt_model - for op in skip_ops_nanogpt: + for op in _skip_ops_nanogpt: if op.name == "gelu": register_default_torch_op(op.torch_reference, torch) else: @@ -160,7 +161,7 @@ def test_nanogpt_block(self): cache_entry, _, _ = thunder.compile_data(jitted).get_computation_and_inputs(x) bwd_trcs = cache_entry.backward_traces - for op in skip_ops_nanogpt: + for op in _skip_ops_nanogpt: vjp_op_name = f"{op.name}_vjp" assert any(bsym.sym.name == vjp_op_name for bsym in bwd_trcs[-1].bound_symbols) @@ -168,7 +169,7 @@ def test_nanogpt_block(self): def test_alexnet(self): torchvision = pytest.importorskip("torchvision") - for op in skip_ops_alexnet: + for op in _skip_ops_alexnet: register_default_torch_op(op.torch_reference, torch.nn.functional) self._tmp_update_jit_lookup(op.torch_reference) tdtype = torch.float32 @@ -182,6 +183,6 @@ def test_alexnet(self): cache_entry, _, _ = thunder.compile_data(jitted).get_computation_and_inputs(x) bwd_trcs = cache_entry.backward_traces - for op in skip_ops_alexnet: + for op in _skip_ops_alexnet: vjp_op_name = f"{op.name}_vjp" assert any(bsym.sym.name == vjp_op_name for bsym in bwd_trcs[-1].bound_symbols) diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index 2caf32ea21..ab4401849a 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -5401,7 +5401,7 @@ def register_default_torch_op(torchfn: Callable, m): backward_impls[sym.id] = bwd_op -# Note this function should be used in side the fake mode context manager +# Note this function should be used inside the fake mode context manager def _get_fake_arg(inp: Any): if inp is None: return inp