Fix auto-registered torch.special operators #979

Merged · 7 commits · Aug 21, 2024
Changes from 3 commits
29 changes: 23 additions & 6 deletions thunder/tests/test_auto_register_torchops.py
@@ -18,9 +18,23 @@
[_name2func.setdefault(f"nn.functional.{v.__name__}", v) for v in ops.torch_auto_registered_ops[torch.nn.functional]]
# Use the sample inputs from torch.xxx to test torch.Tensor.xxx
[_name2func.setdefault(f"Tensor.{v.__name__}", v) for v in ops.torch_auto_registered_ops[torch.Tensor]]
# NOTE: some `opinfo.name`s have no torch.xxx form and are used only as torch.Tensor.xxx
torch_tensor_space_names = [v.__name__ for v in ops.torch_auto_registered_ops[torch.Tensor]]
_opinfos = [opinfo for opinfo in op_db if opinfo.name in _name2func or opinfo.name in torch_tensor_space_names]
[_name2func.setdefault(f"special.{v.__name__.split('_')[1]}", v) for v in ops.torch_auto_registered_ops[torch.special]]
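As context for the `special.` keys above: the auto-registered torch.special functions expose a `special_`-prefixed `__name__` (e.g. `special_entr` for `torch.special.entr`), and the test derives the public name by splitting off that prefix. A minimal standalone sketch of that derivation (the helper name is ours, not part of the PR, and it assumes torch.special.entr.__name__ == "special_entr"):

import torch

def special_key(fn) -> str:
    # Mirrors f"special.{v.__name__.split('_')[1]}" above: "special_entr" -> "special.entr".
    # A name with a second underscore (e.g. "special_log_ndtr") would need
    # split("_", 1)[1] to keep the whole suffix.
    return f"special.{fn.__name__.split('_')[1]}"

print(special_key(torch.special.entr))  # "special.entr", under the assumption above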


def get_opinfos_for_test():
    opinfos = []
    for opinfo in op_db:
        if (
            opinfo.name in _name2func
            or f"Tensor.{opinfo.name}" in _name2func
            or any(alias.name in _name2func or f"Tensor.{alias.name}" in _name2func for alias in opinfo.aliases)
        ):
            opinfos.append(opinfo)

    return opinfos


_opinfos = get_opinfos_for_test()


# Note that successfully catching an expected exception also counts as a pass in this test
@@ -51,7 +65,8 @@ def get_method(op_info):
        return None

    funcs = [_name2func.get(op_info.name, None), get_method(op_info)]
    for func in funcs:
    funcs.extend(_name2func.get(alias.name, None) for alias in op_info.aliases)
    for idx, func in enumerate(funcs):
        if func is None:
            continue
        # Testing every sample input takes too long, so test only the first 5
@@ -72,6 +87,8 @@ def get_method(op_info):
                ).startswith(f"Unsupported type:")
                break
        else:
            # Use the alias name when testing an alias
            cur_op_name = op_info.name if idx < 2 else op_info.aliases[idx - 2].name
            if requires_grad:
                trc = thunder.last_backward_traces(jfun)[-1]
                fwd_trc = thunder.last_traces(jfun)[-1]
@@ -80,15 +97,15 @@
                outs = outs if isinstance(outs, tuple) else (outs,)
                if all(not thunder.core.dtypes.is_inexact_dtype(o.dtype) for o in outs):
                    continue
                vjp_op_name = f"{op_info.name.split('.')[-1]}_vjp"
                vjp_op_name = f"{cur_op_name.split('.')[-1]}_vjp"
                if op_info.name == "mm":
                    assert any(bsym.sym.name.endswith(vjp_op_name) for bsym in trc.bound_symbols)
                else:
                    assert any(bsym.sym.name == vjp_op_name for bsym in trc.bound_symbols)
            else:
                fwd_trc = thunder.last_traces(jfun)[-1]
                assert any(
                    bsym.sym.name.endswith(op_info.name.split(".")[-1]) and not bsym.subsymbols
                    bsym.sym.name.endswith(cur_op_name.split(".")[-1]) and not bsym.subsymbols
                    for bsym in fwd_trc.bound_symbols
                )

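To make the index arithmetic in the test loop concrete: positions 0 and 1 of funcs hold the torch-level function and the Tensor method for op_info.name, and every later position comes from op_info.aliases. A tiny illustration with a hypothetical stand-in for an OpInfo (not taken from op_db):

from types import SimpleNamespace

op_info = SimpleNamespace(name="xlogy", aliases=[SimpleNamespace(name="special.xlogy")])
for idx in range(2 + len(op_info.aliases)):
    cur_op_name = op_info.name if idx < 2 else op_info.aliases[idx - 2].name
    print(idx, cur_op_name)
# 0 xlogy
# 1 xlogy
# 2 special.xlogy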
9 changes: 5 additions & 4 deletions thunder/torch/__init__.py
@@ -5425,15 +5425,16 @@ def register_default_torch_ops():
def register_default_torch_op(torchfn: Callable, torch_module):
    fn_meta = meta_adaptor(torchfn)
    _fn = langctx(Languages.TORCH)(fn_meta)
    torchfn_name = torchfn.__name__.split("_")[1] if torch_module == torch.special else torchfn.__name__
    sym = Symbol(
        name=torchfn.__name__,
        name=torchfn_name,
        meta=_fn,
        id=f"{torch_module.__name__}.{torchfn.__name__}",
        id=f"{torch_module.__name__}.{torchfn_name}",
    )
    _torch_to_thunder_function_map[torchfn] = sym
    from thunder.executors.torchex import _always_executable, ex

    op = ex.register_operator(torchfn.__name__, module=torch_module, meta=fn_meta)
    op = ex.register_operator(torchfn_name, module=torch_module, meta=fn_meta)
    ex.register_implementation(sym, op, checker=_always_executable)

    from thunder.core.transforms import augmented_forward_impls, backward_impls
@@ -5447,7 +5448,7 @@ def register_default_torch_op(torchfn: Callable, torch_module):

    _vjp_impl_wrapper = partial(_vjp_impl, torchfn)

    bwd_op = ex.register_operator(torchfn.__name__ + "_vjp", meta=backward_adaptor(torchfn), fn=_vjp_impl_wrapper)
    bwd_op = ex.register_operator(torchfn_name + "_vjp", meta=backward_adaptor(torchfn), fn=_vjp_impl_wrapper)
    ex.register_implementation(bwd_op.id, bwd_op, checker=_always_executable)
    backward_impls[sym.id] = bwd_op

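The renaming in register_default_torch_op is the heart of the fix: torch.special functions previously produced Symbol names and executor operators with the raw "special_"-prefixed __name__ (and ids like "torch.special.special_entr"). A hedged sketch of the derivation, assuming torch.special.entr.__name__ == "special_entr" (the helper is ours, not an API of the PR):

import torch

def derive_public_name(torchfn, torch_module) -> str:
    # Strip the "special_" prefix for torch.special ops so the registered name
    # matches the public API; all other modules keep __name__ unchanged.
    return torchfn.__name__.split("_")[1] if torch_module is torch.special else torchfn.__name__

print(derive_public_name(torch.special.entr, torch.special))  # entr
print(derive_public_name(torch.erf, torch))                   # erf
# The backward operator is then registered under f"{name}_vjp", e.g. "entr_vjp",
# which is what the updated test asserts against.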