Commit 01a7167
fix tests
kiya00 committed Aug 5, 2024
1 parent 445147d commit 01a7167
Showing 2 changed files with 83 additions and 82 deletions.
163 changes: 82 additions & 81 deletions thunder/tests/test_auto_register_torchops.py
@@ -3,117 +3,118 @@

import pytest
import thunder
import thunder.torch.default_torch_ops as ops
import torch

from thunder.tests.framework import requiresCUDA, TorchExecutor
from thunder.tests.make_tensor import make_tensor
from thunder.tests.opinfos import get_opinfo, OpInfo
from thunder.tests.test_einops import skipIfNoCUDA
from torch.testing._internal.common_device_type import skipCPUIfNoLapack
from torch.testing._internal.common_methods_invocations import op_db

_name2func = {}
[_name2func.setdefault(v.__name__, v) for v in ops.torch_auto_registered_ops[torch]]
[_name2func.setdefault(f"nn.functional.{v.__name__}", v) for v in ops.torch_auto_registered_ops[torch.nn.functional]]
# Use the sample input from torch.xx to test torch.tensor.xx
[_name2func.setdefault(f"Tensor.{v.__name__}", v) for v in ops.torch_auto_registered_ops[torch.Tensor]]
_opinfos = [opinfo for opinfo in op_db if opinfo.name in _name2func]


@skipIfNoCUDA
@pytest.mark.parametrize("op_info,", _opinfos, ids=list(map(lambda opinfo: opinfo.name, _opinfos)))
@pytest.mark.parametrize("requires_grad", [True, False], ids=("train", "inference"))
@pytest.mark.parametrize("device,", ["cuda", "cpu"])
-def test_torch_ops_trace(device, requires_grad):
+def test_torch_ops_trace(device, requires_grad, op_info):
    from itertools import islice

-    import thunder.torch.default_torch_ops as ops
-    from torch.testing._internal.common_methods_invocations import op_db
-    from torch.testing._internal.common_device_type import skipCPUIfNoLapack
-
-    name2func = {}
-    [name2func.setdefault(v.__name__, v) for v in ops.torch_auto_registered_ops[torch]]
-    [name2func.setdefault(f"nn.functional.{v.__name__}", v) for v in ops.torch_auto_registered_ops[torch.nn.functional]]
-    # Use the sample input from torch.xx to test torch.tensor.xx
-    [name2func.setdefault(f"Tensor.{v.__name__}", v) for v in ops.torch_auto_registered_ops[torch.Tensor]]
-    op_infos = [opinfo for opinfo in op_db if opinfo.name in name2func]
    total = 0
-    cnt = 0
-    for op_info in op_infos:
-        if device == "cuda" and torch.float32 not in op_info.dtypesIfCUDA:
-            continue
-        if device == "cpu" and not torch.float32 in op_info.dtypes:
-            continue
-        # No cuda backend support
-        if op_info.name in ("nonzero_static",) and device == "cuda":
-            continue
-        if device == "cpu" and not torch._C.has_lapack and skipCPUIfNoLapack in op_info.decorators:
-            continue
-        funcs = [name2func[op_info.name], name2func.get(f"Tensor.{op_info.name}", None)]
-        for func in funcs:
-            if func is None:
-                continue
-            total += 1
-            # It takes too long, test only the first 5 sample inputs
-            gen = islice(
-                op_info.sample_inputs_func(
-                    op_info, device=torch.device(device), dtype=torch.float32, requires_grad=requires_grad
-                ),
-                5,
-            )
-            for sample in gen:
-                try:
-                    jfun = thunder.jit(func)
-                    out = jfun(sample.input, *sample.args, **sample.kwargs)
-                except Exception as e:
-                    cnt += 1
-                    assert isinstance(e, NotImplementedError)
-                    assert str(e).startswith(f"Exception encountered when doing automatic registration") or str(
-                        e
-                    ).startswith(f"Unsupported type:")
-                    break
-                else:
-                    if requires_grad:
-                        trc = thunder.last_backward_traces(jfun)[-1]
-                        fwd_trc = thunder.last_traces(jfun)[-1]
-                        # skip if it is not differentiable
-                        outs = fwd_trc.output[0]["output"]
-                        outs = outs if isinstance(outs, tuple) else (outs,)
-                        if all(not thunder.core.dtypes.is_inexact_dtype(o.dtype) for o in outs):
-                            continue
-                        vjp_op_name = f"{op_info.name.split('.')[-1]}_vjp"
-                        if op_info.name == "mm":
-                            assert any(bsym.sym.name.endswith(vjp_op_name) for bsym in trc.bound_symbols)
-                        else:
-                            assert any(bsym.sym.name == vjp_op_name for bsym in trc.bound_symbols)
-                    else:
-                        fwd_trc = thunder.last_traces(jfun)[-1]
-                        assert any(
-                            bsym.sym.name.endswith(op_info.name.split(".")[-1]) and not bsym.subsymbols
-                            for bsym in fwd_trc.bound_symbols
-                        )
+    exception_cnt = 0
+    suc = 0
+    # for op_info in op_infos:
+    if device == "cuda" and torch.float32 not in op_info.dtypesIfCUDA:
+        return
+    if device == "cpu" and not torch.float32 in op_info.dtypes:
+        return
+    # No cuda backend support
+    if op_info.name in ("nonzero_static",) and device == "cuda":
+        return
+    if device == "cpu" and not torch._C.has_lapack and skipCPUIfNoLapack in op_info.decorators:
+        return
+    funcs = [_name2func[op_info.name], _name2func.get(f"Tensor.{op_info.name}", None)]
+    for func in funcs:
+        if func is None:
+            continue
+        total += 1
+        # It takes too long, test only the first 5 sample inputs
+        gen = islice(
+            op_info.sample_inputs_func(
+                op_info, device=torch.device(device), dtype=torch.float32, requires_grad=requires_grad
+            ),
+            5,
+        )
+        for sample in gen:
+            try:
+                jfun = thunder.jit(func)
+                out = jfun(sample.input, *sample.args, **sample.kwargs)
+            except Exception as e:
+                exception_cnt += 1
+                assert isinstance(e, NotImplementedError)
+                assert str(e).startswith(f"Exception encountered when doing automatic registration") or str(
+                    e
+                ).startswith(f"Unsupported type:")
+                break
+            else:
+                suc += 1
+                if requires_grad:
+                    trc = thunder.last_backward_traces(jfun)[-1]
+                    fwd_trc = thunder.last_traces(jfun)[-1]
+                    # skip if it is not differentiable
+                    outs = fwd_trc.output[0]["output"]
+                    outs = outs if isinstance(outs, tuple) else (outs,)
+                    if all(not thunder.core.dtypes.is_inexact_dtype(o.dtype) for o in outs):
+                        continue
+                    vjp_op_name = f"{op_info.name.split('.')[-1]}_vjp"
+                    if op_info.name == "mm":
+                        assert any(bsym.sym.name.endswith(vjp_op_name) for bsym in trc.bound_symbols)
+                    else:
+                        assert any(bsym.sym.name == vjp_op_name for bsym in trc.bound_symbols)
+                else:
+                    fwd_trc = thunder.last_traces(jfun)[-1]
+                    assert any(
+                        bsym.sym.name.endswith(op_info.name.split(".")[-1]) and not bsym.subsymbols
+                        for bsym in fwd_trc.bound_symbols
+                    )
+        else:
+            suc += 1


# Replace manual registration of some operations with automatic registration for network test cases
-skip_ops_nanogpt = [
+_skip_ops_nanogpt = [
get_opinfo("layer_norm"),
get_opinfo("linear"),
get_opinfo("gelu"),
get_opinfo("scaled_dot_product_attention"),
]
-skip_ops_alexnet = [
+_skip_ops_alexnet = [
get_opinfo("conv2d"),
get_opinfo("linear"),
get_opinfo("adaptive_avg_pool2d"),
get_opinfo("max_pool2d"),
]
-disable_opinfos = skip_ops_nanogpt + skip_ops_alexnet
-tmp1 = dict(thunder.core.jit_ext._general_jit_lookaside_map)
-list(tmp1.pop(k.torch_reference, None) for k in disable_opinfos)
-tmp2 = dict(thunder.torch._torch_to_thunder_function_map)
-list(tmp2.pop(k.torch_reference, None) for k in disable_opinfos)
-tmp3 = dict(thunder.core.jit_ext._minimal_lookaside_map)
-list(tmp3.pop(k.torch_reference, None) for k in disable_opinfos)
+_disable_opinfos = _skip_ops_nanogpt + _skip_ops_alexnet
+_tmp_general_jit_lookaside_map = dict(thunder.core.jit_ext._general_jit_lookaside_map)
+list(_tmp_general_jit_lookaside_map.pop(k.torch_reference, None) for k in _disable_opinfos)
+_tmp_torch_to_thunder_function_map = dict(thunder.torch._torch_to_thunder_function_map)
+list(_tmp_torch_to_thunder_function_map.pop(k.torch_reference, None) for k in _disable_opinfos)
+_tmp_minimal_lookaside_map = dict(thunder.core.jit_ext._minimal_lookaside_map)
+list(_tmp_minimal_lookaside_map.pop(k.torch_reference, None) for k in _disable_opinfos)
from thunder.torch import register_default_torch_op


# mock all the global variables that are modified during registration
-@patch.dict(thunder.core.jit_ext._general_jit_lookaside_map, tmp1, clear=True)
-@patch.dict(thunder.torch._torch_to_thunder_function_map, tmp2, clear=True)
-@patch.dict(thunder.core.jit_ext._minimal_lookaside_map, tmp3, clear=True)
+@patch.dict(thunder.core.jit_ext._general_jit_lookaside_map, _tmp_general_jit_lookaside_map, clear=True)
+@patch.dict(thunder.torch._torch_to_thunder_function_map, _tmp_torch_to_thunder_function_map, clear=True)
+@patch.dict(thunder.core.jit_ext._minimal_lookaside_map, _tmp_minimal_lookaside_map, clear=True)
@patch.dict(thunder.executors.torchex.ex._implmap, {})
@patch.dict(thunder.executors.torchex.ex._opmap, {})
@patch.dict(thunder.core.transforms.augmented_forward_impls, {})
@@ -141,7 +142,7 @@ def _tmp_update_jit_lookup(self, torchfn):
def test_nanogpt_block(self):
import thunder.tests.nanogpt_model as nanogpt_model

-for op in skip_ops_nanogpt:
+for op in _skip_ops_nanogpt:
if op.name == "gelu":
register_default_torch_op(op.torch_reference, torch)
else:
@@ -160,15 +161,15 @@ def test_nanogpt_block(self):

cache_entry, _, _ = thunder.compile_data(jitted).get_computation_and_inputs(x)
bwd_trcs = cache_entry.backward_traces
-for op in skip_ops_nanogpt:
+for op in _skip_ops_nanogpt:
vjp_op_name = f"{op.name}_vjp"
assert any(bsym.sym.name == vjp_op_name for bsym in bwd_trcs[-1].bound_symbols)

@requiresCUDA
def test_alexnet(self):
torchvision = pytest.importorskip("torchvision")

-for op in skip_ops_alexnet:
+for op in _skip_ops_alexnet:
register_default_torch_op(op.torch_reference, torch.nn.functional)
self._tmp_update_jit_lookup(op.torch_reference)
tdtype = torch.float32
@@ -182,6 +183,6 @@ def test_alexnet(self):

cache_entry, _, _ = thunder.compile_data(jitted).get_computation_and_inputs(x)
bwd_trcs = cache_entry.backward_traces
-for op in skip_ops_alexnet:
+for op in _skip_ops_alexnet:
vjp_op_name = f"{op.name}_vjp"
assert any(bsym.sym.name == vjp_op_name for bsym in bwd_trcs[-1].bound_symbols)
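
The assertions in this test file all rely on the same trace-inspection pattern: jit a callable, call it once, then scan bsym.sym.name over the bound symbols of the last recorded traces (for auto-registered ops the backward trace is expected to contain a "<op>_vjp" symbol). Below is a minimal standalone sketch of that inspection outside the test harness; gelu and the tensor shapes are arbitrary illustrative choices, not taken from this diff.

```python
import torch
import thunder


def f(x):
    return torch.nn.functional.gelu(x)


jf = thunder.jit(f)
x = torch.randn(2, 3, requires_grad=True)
jf(x)

# The last entries are the traces that actually execute; each records a list of bound symbols.
fwd_trc = thunder.last_traces(jf)[-1]
bwd_trc = thunder.last_backward_traces(jf)[-1]
print([bsym.sym.name for bsym in fwd_trc.bound_symbols])
print([bsym.sym.name for bsym in bwd_trc.bound_symbols])
```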
2 changes: 1 addition & 1 deletion thunder/torch/__init__.py
@@ -5401,7 +5401,7 @@ def register_default_torch_op(torchfn: Callable, m):
backward_impls[sym.id] = bwd_op


-# Note this function should be used in side the fake mode context manager
+# Note this function should be used inside the fake mode context manager
def _get_fake_arg(inp: Any):
if inp is None:
return inp
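
The comment fixed above refers to PyTorch's fake-tensor mode: _get_fake_arg appears to build fake stand-ins for real arguments so an op can be evaluated for metadata only, without real compute. Below is a rough, self-contained sketch of that general pattern using the public FakeTensorMode API; it illustrates the idea only and is not Thunder's implementation, and infer_output_meta is a made-up helper name.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode


def infer_output_meta(torchfn, *args, **kwargs):
    # Convert real tensors to fake tensors and run the op under fake mode,
    # so only shapes/dtypes/devices are computed, never real values.
    mode = FakeTensorMode()
    with mode:
        fake_args = [mode.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args]
        fake_kwargs = {k: (mode.from_tensor(v) if isinstance(v, torch.Tensor) else v) for k, v in kwargs.items()}
        out = torchfn(*fake_args, **fake_kwargs)
    return out.shape, out.dtype, out.device


print(infer_output_meta(torch.nn.functional.relu, torch.randn(2, 3)))
```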