Commit

Fix test: use first 5 sample inputs to run faster
kiya00 committed Jul 29, 2024
1 parent 4f8afa9 commit f6126fc
Showing 2 changed files with 55 additions and 46 deletions.
85 changes: 47 additions & 38 deletions thunder/tests/test_fallback_to_torchops.py
@@ -14,26 +14,26 @@
 @skipIfNoCUDA
 @pytest.mark.parametrize("requires_grad", [True, False], ids=("train", "inference"))
 @pytest.mark.parametrize("device,", ["cuda", "cpu"])
-def test_torch_ops_backward(device, requires_grad):
-    from itertools import chain
+def test_torch_ops(device, requires_grad):
+    from itertools import islice
 
     import thunder.torch.default_torch_ops as ops
     from torch.testing._internal.common_methods_invocations import op_db
 
     name2func = {}
     [name2func.setdefault(v.__name__, v) for v in ops.torch_fallback_ops[torch]]
     [name2func.setdefault(f"nn.functional.{v.__name__}", v) for v in ops.torch_fallback_ops[torch.nn.functional]]
+    # Use the sample input from torch.xx to test torch.tensor.xx
+    [name2func.setdefault(f"Tensor.{v.__name__}", v) for v in ops.torch_fallback_ops[torch.Tensor]]
     print(f"auto reg ops: {len(name2func)}")
-    # [name2func.setdefault(f"Tensor.{v.__name__}", v) for v in ops.torch_fallback_ops[torch.Tensor]]
     op_infos = [opinfo for opinfo in op_db if opinfo.name in name2func.keys()]
-    total = len(op_infos)
+    total = 0
     cnt = 0
+    suc = 0
     for op_info in op_infos:
         if device == "cuda" and torch.float32 not in op_info.dtypesIfCUDA:
-            total -= 1
             continue
         if device == "cpu" and not torch.float32 in op_info.dtypes:
-            total -= 1
             continue
 
         # No cuda backend support
@@ -43,40 +43,49 @@ def test_torch_ops_backward(device, requires_grad):
         # ValueError: not enough values to unpack (expected 6, got 2)
         if op_info.name == "histogramdd" and device == "cpu" and not requires_grad:
             continue
-        for sample in op_info.sample_inputs_func(
-            op_info, device=torch.device(device), dtype=torch.float32, requires_grad=requires_grad
-        ):
-            if op_info.name == "searchsorted" and (requires_grad and not sample.input.requires_grad):
-                continue
-            try:
-                jfun = thunder.jit(name2func[op_info.name])
-                out = jfun(sample.input, *sample.args, **sample.kwargs)
-            except Exception as e:
-                cnt += 1
-                assert isinstance(e, NotImplementedError)
-                assert str(e).startswith(f"Exception encountered when doing automatic registration") or str(
-                    e
-                ).startswith(f"Unsupported type:")
-                # print(e)
-                # print(f"unsupported: {op_info.name}")
-                # print("--------------------")
-                break
-            else:
-                if requires_grad:
-                    trc = thunder.last_backward_traces(jfun)[-1]
-                    trcf = thunder.last_traces(jfun)[-1]
-                    # skip if it is not differentiable
-                    outs = trcf.output[0]["output"]
-                    outs = outs if isinstance(outs, tuple) else (outs,)
-                    if all(not thunder.core.dtypes.is_inexact_dtype(o.dtype) for o in outs):
-                        continue
-                    vjp_op_name = f"{op_info.name.split('.')[-1]}_vjp"
-                    if op_info.name == "mm":
-                        assert any(bsym.sym.name.endswith(vjp_op_name) for bsym in trc.bound_symbols)
-                    else:
-                        assert any(bsym.sym.name == vjp_op_name for bsym in trc.bound_symbols)
+        funcs = [name2func[op_info.name], name2func.get(f"Tensor.{op_info.name}", None)]
+        for func in funcs:
+            if func is None:
+                continue
+            total += 1
+            # It takes too long, test only the first 5 sample inputs
+            gen = islice(
+                op_info.sample_inputs_func(
+                    op_info, device=torch.device(device), dtype=torch.float32, requires_grad=requires_grad
+                ),
+                5,
+            )
+            for sample in gen:
+                if op_info.name == "searchsorted" and (requires_grad and not sample.input.requires_grad):
+                    continue
+                try:
+                    jfun = thunder.jit(func)
+                    out = jfun(sample.input, *sample.args, **sample.kwargs)
+                except Exception as e:
+                    cnt += 1
+                    assert isinstance(e, NotImplementedError)
+                    assert str(e).startswith(f"Exception encountered when doing automatic registration") or str(
+                        e
+                    ).startswith(f"Unsupported type:")
+                    break
+                else:
+                    if requires_grad:
+                        trc = thunder.last_backward_traces(jfun)[-1]
+                        trcf = thunder.last_traces(jfun)[-1]
+                        # skip if it is not differentiable
+                        outs = trcf.output[0]["output"]
+                        outs = outs if isinstance(outs, tuple) else (outs,)
+                        if all(not thunder.core.dtypes.is_inexact_dtype(o.dtype) for o in outs):
+                            continue
+                        vjp_op_name = f"{op_info.name.split('.')[-1]}_vjp"
+                        if op_info.name == "mm":
+                            assert any(bsym.sym.name.endswith(vjp_op_name) for bsym in trc.bound_symbols)
+                        else:
+                            assert any(bsym.sym.name == vjp_op_name for bsym in trc.bound_symbols)
 
-    print(f"needs manual registeration: {cnt}, total: {total}")
+            suc += 1
+
+    print(f"total number of ops with sample input: {total}, success: {suc}, needs manual registeration: {cnt}")
 
 
 # Replace manual registration of some operations with automatic registration for network test cases
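Note (not part of the commit): a minimal, standalone sketch of the islice pattern the updated test uses to cap each OpInfo at its first 5 sample inputs. It assumes a PyTorch build that exposes the internal op_db imported above; the "add" OpInfo and the CPU/float32/no-grad settings are arbitrary choices for illustration.

from itertools import islice

import torch
from torch.testing._internal.common_methods_invocations import op_db

# Pick one OpInfo; "add" is chosen arbitrarily for the sketch.
op_info = next(op for op in op_db if op.name == "add")

# islice lazily stops the sample-input generator after 5 items, so the
# remaining samples are never constructed; that is where the test-time saving comes from.
samples = islice(
    op_info.sample_inputs_func(
        op_info, device=torch.device("cpu"), dtype=torch.float32, requires_grad=False
    ),
    5,
)
for sample in samples:
    # Each SampleInput carries .input, .args and .kwargs, as used in the test above.
    out = op_info.op(sample.input, *sample.args, **sample.kwargs)
    print(type(out))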
16 changes: 8 additions & 8 deletions thunder/torch/__init__.py
@@ -5364,15 +5364,15 @@ def wait(slf) -> None:
 #
 
 
-def _is_differentiable(arg):
+def _is_differentiable(arg: Any):
     from torch._subclasses.fake_tensor import FakeTensor
 
     if isinstance(arg, (torch.Tensor, FakeTensor, TensorProxy)):
         return dtypes.is_inexact_dtype(to_dtype(arg.dtype))
     return False
 
 
-def _make_differentiable_wrapper(func, *args, **kwargs):
+def _make_differentiable_wrapper(func: Callable, *args, **kwargs):
     from thunder.core.pytree import tree_flatten, tree_unflatten
 
     flat_args, spec = tree_flatten((args, kwargs))
@@ -5396,7 +5396,7 @@ def register_default_torch_ops():
         register_default_torch_op(fn, meta_adaptor(fn), m)
 
 
-def register_default_torch_op(torchfn, fn_meta, m):
+def register_default_torch_op(torchfn: Callable, fn_meta: Callable, m):
     _fn = langctx(Languages.TORCH)(fn_meta)
     sym = Symbol(
         name=torchfn.__name__,
@@ -5431,7 +5431,7 @@ def _vjp_impl(residules, *gs) -> torch.Tensor:
     backward_impls[sym.id] = bwd_op
 
 
-def _get_fake_arg(inp, fake_mode):
+def _get_fake_arg(inp: Any, fake_mode):
     from thunder.core.devices import to_torch_device
     from thunder.core.proxies import TupleProxy, ListProxy, DictProxy
 
@@ -5465,7 +5465,7 @@ def _get_fake_arg(inp, fake_mode):
     raise NotImplementedError(f"Unsupported type: {type(inp)}")
 
 
-def _fake_type_to_thunder(inp):
+def _fake_type_to_thunder(inp: Any):
     from builtins import type
 
     from thunder.core.proxies import _cls_to_number_proxy_map, numberproxy
@@ -5494,7 +5494,7 @@ def _fake_type_to_thunder(inp):
     raise NotImplementedError(f"Unsupported type: {type(inp)}")
 
 
-def augmented_forward_adaptor(sym_op):
+def augmented_forward_adaptor(sym_op: Callable):
     def wrapper(*args, **kwargs):
         from thunder.core.transforms import VJPDual
 
@@ -5506,7 +5506,7 @@ def wrapper(*args, **kwargs):
     return wrapper
 
 
-def backward_adaptor(torch_func):
+def backward_adaptor(torch_func: Callable):
     def wrapper(saved_for_backward, *grad_output):
         inp_args, inp_kwargs = saved_for_backward
         from thunder.core.pytree import tree_flatten, tree_unflatten
@@ -5540,7 +5540,7 @@ def wrapper(saved_for_backward, *grad_output):
     return wrapper
 
 
-def meta_adaptor(torch_func):
+def meta_adaptor(torch_func: Callable):
     def wrapper(*args, **kwargs):
         from thunder.core.pytree import tree_flatten, tree_unflatten
         from torch._subclasses.fake_tensor import FakeTensorMode
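Note (not part of the commit): the helpers annotated above (meta_adaptor, _get_fake_arg, _fake_type_to_thunder) build on PyTorch's FakeTensorMode, which evaluates an op on storage-less fake tensors so only output shapes and dtypes are derived, with no real kernels run. A minimal sketch of that pattern, with relu chosen arbitrarily:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

fake_mode = FakeTensorMode()

# Turn a real tensor into a FakeTensor: same shape/dtype/device, no storage.
fake_in = fake_mode.from_tensor(torch.randn(2, 3))

# Dispatch the op under the fake mode; no real computation happens.
with fake_mode:
    fake_out = torch.nn.functional.relu(fake_in)

print(fake_out.shape, fake_out.dtype)  # torch.Size([2, 3]) torch.float32

Judging by the helper names, _get_fake_arg maps thunder proxies to fake tensors and _fake_type_to_thunder maps the fake results back to thunder types, which is how register_default_torch_op can derive a meta function for an arbitrary fallback torch op.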

0 comments on commit f6126fc