diff --git a/thunder/__init__.py b/thunder/__init__.py index 9686697558..1313c50178 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -1014,6 +1014,17 @@ def last_compile_options(fn: Callable, /) -> None: print(f"\t{option}") +def print_auto_registered_torch_ops(fn: Callable, /) -> set[str] | None: + """Returns a set of auto-registered Torch operator names present in the given JIT-compiled function.""" + op_names = set() + trc = last_traces(fn)[0] + for bsym in trc.bound_symbols: + if (meta_func_name := getattr(bsym.sym.meta, "__name__", None)) and meta_func_name == "meta_func": + op_names.add(bsym.sym.id) + # import pdb;pdb.set_trace() + return op_names if op_names else None + + # TODO (mruberry) Update this def _grad_transform(trace): grad_fwd_trace = from_trace(trace) diff --git a/thunder/tests/test_auto_register_torchops.py b/thunder/tests/test_auto_register_torchops.py index e19e23675a..64fc825ab7 100644 --- a/thunder/tests/test_auto_register_torchops.py +++ b/thunder/tests/test_auto_register_torchops.py @@ -209,16 +209,19 @@ def test_alexnet(self): @instantiate(dtypes=NOTHING) -def test_compile_stats_auto_reg_record(executor, device: str, dtype): +def test_query_autoreg_ops(executor, device: str, _): def fn(a): x = torch.special.gammaln(torch.special.zeta(torch.special.gammaln(a), a)) return torch.special.erf(x) - cfn = executor.make_callable(fn) + def fn_none(a): + return torch.nn.functional.relu(a) - a = make_tensor((2, 2), device=device, dtype=torch.float32) - cfn(a) - cs = thunder.compile_stats(cfn) - assert len(cs.auto_registered_torch_operators) == 2 - assert "torch.special.gammaln" in cs.auto_registered_torch_operators - assert "torch.special.erf" in cs.auto_registered_torch_operators + expected = ({"torch.special.erf", "torch.special.gammaln"}, None) + for fn, expect in zip((fn, fn_none), expected): + cfn = executor.make_callable(fn) + + a = make_tensor((2, 2), device=device, dtype=torch.float32) + cfn(a) + ops = thunder.print_auto_registered_torch_ops(cfn) + assert expect == ops