Skip to content

Commit

Permalink
Add thunderfx API (#1529)
Browse files Browse the repository at this point in the history
  • Loading branch information
kiya00 committed Dec 10, 2024
1 parent 087637f commit 06e7b36
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 4 deletions.
6 changes: 2 additions & 4 deletions thunder/dynamo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from thunder.dynamo.compiler import ThunderCompiler
from thunder.dynamo.compiler import ThunderCompiler, thunderfx


__all__ = [
"ThunderCompiler",
]
__all__ = ["ThunderCompiler", "thunderfx"]
32 changes: 32 additions & 0 deletions thunder/dynamo/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,20 @@
from looseversion import LooseVersion
from typing import TYPE_CHECKING
import warnings
import inspect

import torch

from thunder.core.baseutils import run_once
from thunder.core.utils import safe_zip
from thunder.dynamo.utils import recompile_graph, remove_empty_autocast, reproducer, CompilerType
from thunder.dynamo.splitter import _splitter
from thunder.core.utils import check

if TYPE_CHECKING:
from thunder.dynamo.utils import SubgraphInfo
from os import PathLike
from collections.abc import Callable


@run_once
Expand Down Expand Up @@ -127,3 +130,32 @@ def save_reproducer_to_folder(self, reproducer_folder: str | PathLike, use_pytes
f"graph{graph_idx}_{cur_name}",
use_pytest_benchmark,
)


def thunderfx(fn: Callable, /, **kwargs) -> Callable:
"""Compiles a callable (function or model) by using Thunder as the backend of :func:`torch.compile`
Args:
fn: A :class:`~torch.nn.Module` or a function to compile.
Keyword Args:
**kwargs: a dictionary of options to pass to :func:`torch.compile` and :func:`thunder.jit`.
Returns:
The compiled callable
"""
import thunder

torch_compile_kwarg_names = inspect.getfullargspec(torch.compile).kwonlyargs
thunder_jit_kwarg_names = inspect.getfullargspec(thunder.jit).kwonlyargs
overlap = [kwarg_name for kwarg_name in thunder_jit_kwarg_names if kwarg_name in torch_compile_kwarg_names]
check(
not overlap,
lambda: f"There are overlapping kwargs between thunder.jit and torch.compile: {overlap}",
ValueError,
)

torch_compile_options = {k: v for k, v in kwargs.items() if k in torch_compile_kwarg_names}
thunder_options = {k: v for k, v in kwargs.items() if k not in torch_compile_kwarg_names}

backend = ThunderCompiler(**thunder_options)
compiled = torch.compile(fn, backend=backend, **torch_compile_options)
compiled._backend = backend
return compiled
23 changes: 23 additions & 0 deletions thunder/tests/test_dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,3 +930,26 @@ def check(file_name, cmd):
cmd = "pytest" if use_pytest_benchmark else "python"
for fname in [s1, s2, s3]:
check(fname, cmd)


def test_thunderfx():
from thunder.dynamo import thunderfx

def foo(x):
return torch.sin(x) + torch.cos(x)

x = torch.randn(4, 4, device="cuda", requires_grad=True)
cfoo = thunderfx(foo)
cfoo(x)
thunder_compiled_fns = cfoo._backend.subgraph_infos[0].thunder_compiled_fns
assert len(thunder_compiled_fns) == 1
assert last_traces(thunder_compiled_fns[0])

from thunder.dev_utils.nvtx_profile_transform import NvtxProfileTransform

cfoo = thunderfx(foo, dynamic=True, transforms=[NvtxProfileTransform()])
cfoo(x)
thunder_compiled_fns = cfoo._backend.subgraph_infos[0].thunder_compiled_fns
assert len(thunder_compiled_fns) == 1
trc = last_traces(thunder_compiled_fns[-1])[-1]
assert any(bsym.sym.id == "nvtx_range_push" for bsym in trc.bound_symbols)

0 comments on commit 06e7b36

Please sign in to comment.