diff --git a/thunder/dynamo/__init__.py b/thunder/dynamo/__init__.py index 3993f6f18d..6458c5d0eb 100644 --- a/thunder/dynamo/__init__.py +++ b/thunder/dynamo/__init__.py @@ -1,6 +1,4 @@ -from thunder.dynamo.compiler import ThunderCompiler +from thunder.dynamo.compiler import ThunderCompiler, thunderfx -__all__ = [ - "ThunderCompiler", -] +__all__ = ["ThunderCompiler", "thunderfx"] diff --git a/thunder/dynamo/compiler.py b/thunder/dynamo/compiler.py index c7958be9b1..61562fed85 100644 --- a/thunder/dynamo/compiler.py +++ b/thunder/dynamo/compiler.py @@ -3,6 +3,7 @@ from looseversion import LooseVersion from typing import TYPE_CHECKING import warnings +import inspect import torch @@ -10,10 +11,12 @@ from thunder.core.utils import safe_zip from thunder.dynamo.utils import recompile_graph, remove_empty_autocast, reproducer, CompilerType from thunder.dynamo.splitter import _splitter +from thunder.core.utils import check if TYPE_CHECKING: from thunder.dynamo.utils import SubgraphInfo from os import PathLike + from collections.abc import Callable @run_once @@ -127,3 +130,32 @@ def save_reproducer_to_folder(self, reproducer_folder: str | PathLike, use_pytes f"graph{graph_idx}_{cur_name}", use_pytest_benchmark, ) + + +def thunderfx(fn: Callable, /, **kwargs) -> Callable: + """Compiles a callable (function or model) by using Thunder as the backend of :func:`torch.compile` + Args: + fn: A :class:`~torch.nn.Module` or a function to compile. + Keyword Args: + **kwargs: a dictionary of options to pass to :func:`torch.compile` and :func:`thunder.jit`. + Returns: + The compiled callable + """ + import thunder + + torch_compile_kwarg_names = inspect.getfullargspec(torch.compile).kwonlyargs + thunder_jit_kwarg_names = inspect.getfullargspec(thunder.jit).kwonlyargs + overlap = [kwarg_name for kwarg_name in thunder_jit_kwarg_names if kwarg_name in torch_compile_kwarg_names] + check( + not overlap, + lambda: f"There are overlapping kwargs between thunder.jit and torch.compile: {overlap}", + ValueError, + ) + + torch_compile_options = {k: v for k, v in kwargs.items() if k in torch_compile_kwarg_names} + thunder_options = {k: v for k, v in kwargs.items() if k not in torch_compile_kwarg_names} + + backend = ThunderCompiler(**thunder_options) + compiled = torch.compile(fn, backend=backend, **torch_compile_options) + compiled._backend = backend + return compiled diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py index 0ae470b1fc..0c51fce3cf 100644 --- a/thunder/tests/test_dynamo.py +++ b/thunder/tests/test_dynamo.py @@ -10,7 +10,7 @@ from looseversion import LooseVersion from thunder import dtypes -from thunder.dynamo import ThunderCompiler +from thunder.dynamo import ThunderCompiler, thunderfx from thunder.dynamo.utils import CompilerType from thunder.dynamo.compiler_graph_benchmark import ThunderCompilerGraphBenchmarking from thunder import last_traces @@ -930,3 +930,25 @@ def check(file_name, cmd): cmd = "pytest" if use_pytest_benchmark else "python" for fname in [s1, s2, s3]: check(fname, cmd) + + +@requiresCUDA +def test_thunderfx(): + def foo(x): + return torch.sin(x) + torch.cos(x) + + x = torch.randn(4, 4, device="cuda", requires_grad=True) + cfoo = thunderfx(foo) + cfoo(x) + thunder_compiled_fns = cfoo._backend.subgraph_infos[0].thunder_compiled_fns + assert len(thunder_compiled_fns) == 1 + assert last_traces(thunder_compiled_fns[0]) + + from thunder.dev_utils.nvtx_profile_transform import NvtxProfileTransform + + cfoo = thunderfx(foo, dynamic=True, transforms=[NvtxProfileTransform()]) + cfoo(x) + thunder_compiled_fns = cfoo._backend.subgraph_infos[0].thunder_compiled_fns + assert len(thunder_compiled_fns) == 1 + trc = last_traces(thunder_compiled_fns[-1])[-1] + assert any(bsym.sym.id == "nvtx_range_push" for bsym in trc.bound_symbols)