Skip to content

Commit

Permalink
UX: Add print_auto_registered_torch_ops to query the auto-registere…
Browse files Browse the repository at this point in the history
…d ops on jitted function
  • Loading branch information
kiya00 committed Aug 22, 2024
1 parent e9a78e9 commit 68ce500
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 8 deletions.
11 changes: 11 additions & 0 deletions thunder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1014,6 +1014,17 @@ def last_compile_options(fn: Callable, /) -> None:
print(f"\t{option}")


def print_auto_registered_torch_ops(fn: Callable, /) -> set[str] | None:
"""Returns a set of auto-registered Torch operator names present in the given JIT-compiled function."""
op_names = set()
trc = last_traces(fn)[0]
for bsym in trc.bound_symbols:
if (meta_func_name := getattr(bsym.sym.meta, "__name__", None)) and meta_func_name == "meta_func":
op_names.add(bsym.sym.id)
# import pdb;pdb.set_trace()
return op_names if op_names else None


# TODO (mruberry) Update this
def _grad_transform(trace):
grad_fwd_trace = from_trace(trace)
Expand Down
19 changes: 11 additions & 8 deletions thunder/tests/test_auto_register_torchops.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,16 +209,19 @@ def test_alexnet(self):


@instantiate(dtypes=NOTHING)
def test_compile_stats_auto_reg_record(executor, device: str, dtype):
def test_query_autoreg_ops(executor, device: str, _):
def fn(a):
x = torch.special.gammaln(torch.special.zeta(torch.special.gammaln(a), a))
return torch.special.erf(x)

cfn = executor.make_callable(fn)
def fn_none(a):
return torch.nn.functional.relu(a)

a = make_tensor((2, 2), device=device, dtype=torch.float32)
cfn(a)
cs = thunder.compile_stats(cfn)
assert len(cs.auto_registered_torch_operators) == 2
assert "torch.special.gammaln" in cs.auto_registered_torch_operators
assert "torch.special.erf" in cs.auto_registered_torch_operators
expected = ({"torch.special.erf", "torch.special.gammaln"}, None)
for fn, expect in zip((fn, fn_none), expected):
cfn = executor.make_callable(fn)

a = make_tensor((2, 2), device=device, dtype=torch.float32)
cfn(a)
ops = thunder.print_auto_registered_torch_ops(cfn)
assert expect == ops

0 comments on commit 68ce500

Please sign in to comment.