Skip to content

Commit

Permalink
fix for testing alias
Browse files Browse the repository at this point in the history
  • Loading branch information
kiya00 committed Aug 16, 2024
1 parent a19c44f commit 4a29df5
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions thunder/tests/test_auto_register_torchops.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def get_method(op_info):

funcs = [_name2func.get(op_info.name, None), get_method(op_info)]
funcs.extend(_name2func.get(alias.name, None) for alias in op_info.aliases)
for func in funcs:
for idx, func in enumerate(funcs):
if func is None:
continue
# It takes too long, test only the first 5 sample inputs
Expand All @@ -87,6 +87,8 @@ def get_method(op_info):
).startswith(f"Unsupported type:")
break
else:
# Get the alias name when testing for alias
cur_op_name = op_info.name if idx < 2 else op_info.aliases[idx - 2].name
if requires_grad:
trc = thunder.last_backward_traces(jfun)[-1]
fwd_trc = thunder.last_traces(jfun)[-1]
Expand All @@ -95,15 +97,15 @@ def get_method(op_info):
outs = outs if isinstance(outs, tuple) else (outs,)
if all(not thunder.core.dtypes.is_inexact_dtype(o.dtype) for o in outs):
continue
vjp_op_name = f"{op_info.name.split('.')[-1]}_vjp"
vjp_op_name = f"{cur_op_name.split('.')[-1]}_vjp"
if op_info.name == "mm":
assert any(bsym.sym.name.endswith(vjp_op_name) for bsym in trc.bound_symbols)
else:
assert any(bsym.sym.name == vjp_op_name for bsym in trc.bound_symbols)
else:
fwd_trc = thunder.last_traces(jfun)[-1]
assert any(
bsym.sym.name.endswith(op_info.name.split(".")[-1]) and not bsym.subsymbols
bsym.sym.name.endswith(cur_op_name.split(".")[-1]) and not bsym.subsymbols
for bsym in fwd_trc.bound_symbols
)

Expand Down

0 comments on commit 4a29df5

Please sign in to comment.