diff --git a/thunder/dynamo/__init__.py b/thunder/dynamo/__init__.py index 3993f6f18d..6458c5d0eb 100644 --- a/thunder/dynamo/__init__.py +++ b/thunder/dynamo/__init__.py @@ -1,6 +1,4 @@ -from thunder.dynamo.compiler import ThunderCompiler +from thunder.dynamo.compiler import ThunderCompiler, thunderfx -__all__ = [ - "ThunderCompiler", -] +__all__ = ["ThunderCompiler", "thunderfx"] diff --git a/thunder/dynamo/compiler.py b/thunder/dynamo/compiler.py index c7958be9b1..510873076b 100644 --- a/thunder/dynamo/compiler.py +++ b/thunder/dynamo/compiler.py @@ -3,6 +3,7 @@ from looseversion import LooseVersion from typing import TYPE_CHECKING import warnings +import inspect import torch @@ -10,18 +11,12 @@ from thunder.core.utils import safe_zip from thunder.dynamo.utils import recompile_graph, remove_empty_autocast, reproducer, CompilerType from thunder.dynamo.splitter import _splitter +from thunder.core.utils import check if TYPE_CHECKING: from thunder.dynamo.utils import SubgraphInfo from os import PathLike - - -@run_once -def _warn_thunder_compiler(): - warnings.warn( - "The ThunderCompiler is in active development and may not work as expected." - + " Please report any issues you encounter to the Lightning Thunder team." - ) + from collections.abc import Callable class ThunderCompiler: @@ -32,9 +27,7 @@ def __init__(self, **thunder_options): function. Keyword arguments: - thunder_options: a dictionary of options to pass to :func:`thunder.jit`. Besides all the arguments to :func:`thunder.jit`, - it accepts ``torch_inductor_options`` which are passed to :func:`torch.compile` if part of the graph - is not supported by thunder. + thunder_options: a dictionary of options to pass to :func:`thunder.jit`. Example: >>> import torch @@ -52,8 +45,6 @@ def __init__(self, **thunder_options): """ from thunder import jit - _warn_thunder_compiler() - if LooseVersion(torch.__version__) < LooseVersion("2.4.0"): # NOTE: PyTorch 2.3 or lower has bug in `split_module` function used in splitter. # See https://github.com/Lightning-AI/lightning-thunder/pull/1075#issuecomment-2324918409 @@ -67,11 +58,9 @@ def __init__(self, **thunder_options): # Ref to the documentation of `SubgraphInfo` to know more about the information it contains. self.subgraph_infos: list[SubgraphInfo] = [] - torch_inductor_options = thunder_options.pop("torch_inductor_options", {}) - self.thunder_options = thunder_options self._thunder_jit = partial(jit, **thunder_options) - self._torch_compile = partial(torch.compile, **torch_inductor_options) + self._torch_compile = torch.compile def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, torch.Tensor]): gm = remove_empty_autocast(gm) @@ -127,3 +116,32 @@ def save_reproducer_to_folder(self, reproducer_folder: str | PathLike, use_pytes f"graph{graph_idx}_{cur_name}", use_pytest_benchmark, ) + + +def thunderfx(fn: Callable, /, **kwargs) -> Callable: + """Compiles a callable (function or model) by using Thunder as the backend of :func:`torch.compile` + Args: + fn: A :class:`~torch.nn.Module` or a function to compile. + Keyword Args: + **kwargs: a dictionary of options to pass to :func:`torch.compile` and :func:`thunder.jit`. + Returns: + The compiled callable + """ + import thunder + + torch_compile_kwarg_names = inspect.getfullargspec(torch.compile).kwonlyargs + thunder_jit_kwarg_names = inspect.getfullargspec(thunder.jit).kwonlyargs + overlap = [kwarg_name for kwarg_name in thunder_jit_kwarg_names if kwarg_name in torch_compile_kwarg_names] + check( + not overlap, + lambda: f"There are overlapping kwargs between thunder.jit and torch.compile: {overlap}", + ValueError, + ) + + torch_compile_options = {k: v for k, v in kwargs.items() if k in torch_compile_kwarg_names} + thunder_options = {k: v for k, v in kwargs.items() if k not in torch_compile_kwarg_names} + + backend = ThunderCompiler(**thunder_options) + compiled = torch.compile(fn, backend=backend, **torch_compile_options) + compiled._backend = backend + return compiled diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py index 0ae470b1fc..0c51fce3cf 100644 --- a/thunder/tests/test_dynamo.py +++ b/thunder/tests/test_dynamo.py @@ -10,7 +10,7 @@ from looseversion import LooseVersion from thunder import dtypes -from thunder.dynamo import ThunderCompiler +from thunder.dynamo import ThunderCompiler, thunderfx from thunder.dynamo.utils import CompilerType from thunder.dynamo.compiler_graph_benchmark import ThunderCompilerGraphBenchmarking from thunder import last_traces @@ -930,3 +930,25 @@ def check(file_name, cmd): cmd = "pytest" if use_pytest_benchmark else "python" for fname in [s1, s2, s3]: check(fname, cmd) + + +@requiresCUDA +def test_thunderfx(): + def foo(x): + return torch.sin(x) + torch.cos(x) + + x = torch.randn(4, 4, device="cuda", requires_grad=True) + cfoo = thunderfx(foo) + cfoo(x) + thunder_compiled_fns = cfoo._backend.subgraph_infos[0].thunder_compiled_fns + assert len(thunder_compiled_fns) == 1 + assert last_traces(thunder_compiled_fns[0]) + + from thunder.dev_utils.nvtx_profile_transform import NvtxProfileTransform + + cfoo = thunderfx(foo, dynamic=True, transforms=[NvtxProfileTransform()]) + cfoo(x) + thunder_compiled_fns = cfoo._backend.subgraph_infos[0].thunder_compiled_fns + assert len(thunder_compiled_fns) == 1 + trc = last_traces(thunder_compiled_fns[-1])[-1] + assert any(bsym.sym.id == "nvtx_range_push" for bsym in trc.bound_symbols)