Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add thunderfx API #1535

Merged
merged 3 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions thunder/dynamo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from thunder.dynamo.compiler import ThunderCompiler
from thunder.dynamo.compiler import ThunderCompiler, thunderfx


__all__ = [
"ThunderCompiler",
]
__all__ = ["ThunderCompiler", "thunderfx"]
32 changes: 32 additions & 0 deletions thunder/dynamo/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,20 @@
from looseversion import LooseVersion
from typing import TYPE_CHECKING
import warnings
import inspect

import torch

from thunder.core.baseutils import run_once
from thunder.core.utils import safe_zip
from thunder.dynamo.utils import recompile_graph, remove_empty_autocast, reproducer, CompilerType
from thunder.dynamo.splitter import _splitter
from thunder.core.utils import check

if TYPE_CHECKING:
from thunder.dynamo.utils import SubgraphInfo
from os import PathLike
from collections.abc import Callable


@run_once
Expand Down Expand Up @@ -127,3 +130,32 @@ def save_reproducer_to_folder(self, reproducer_folder: str | PathLike, use_pytes
f"graph{graph_idx}_{cur_name}",
use_pytest_benchmark,
)


def thunderfx(fn: Callable, /, **kwargs) -> Callable:
"""Compiles a callable (function or model) by using Thunder as the backend of :func:`torch.compile`
Args:
fn: A :class:`~torch.nn.Module` or a function to compile.
Keyword Args:
**kwargs: a dictionary of options to pass to :func:`torch.compile` and :func:`thunder.jit`.
Returns:
The compiled callable
"""
import thunder

torch_compile_kwarg_names = inspect.getfullargspec(torch.compile).kwonlyargs
thunder_jit_kwarg_names = inspect.getfullargspec(thunder.jit).kwonlyargs
overlap = [kwarg_name for kwarg_name in thunder_jit_kwarg_names if kwarg_name in torch_compile_kwarg_names]
mruberry marked this conversation as resolved.
Show resolved Hide resolved
check(
not overlap,
lambda: f"There are overlapping kwargs between thunder.jit and torch.compile: {overlap}",
ValueError,
)

torch_compile_options = {k: v for k, v in kwargs.items() if k in torch_compile_kwarg_names}
thunder_options = {k: v for k, v in kwargs.items() if k not in torch_compile_kwarg_names}

backend = ThunderCompiler(**thunder_options)
mruberry marked this conversation as resolved.
Show resolved Hide resolved
compiled = torch.compile(fn, backend=backend, **torch_compile_options)
compiled._backend = backend
return compiled
24 changes: 23 additions & 1 deletion thunder/tests/test_dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from looseversion import LooseVersion

from thunder import dtypes
from thunder.dynamo import ThunderCompiler
from thunder.dynamo import ThunderCompiler, thunderfx
from thunder.dynamo.utils import CompilerType
from thunder.dynamo.compiler_graph_benchmark import ThunderCompilerGraphBenchmarking
from thunder import last_traces
Expand Down Expand Up @@ -930,3 +930,25 @@ def check(file_name, cmd):
cmd = "pytest" if use_pytest_benchmark else "python"
for fname in [s1, s2, s3]:
check(fname, cmd)


@requiresCUDA
def test_thunderfx():
def foo(x):
return torch.sin(x) + torch.cos(x)

x = torch.randn(4, 4, device="cuda", requires_grad=True)
cfoo = thunderfx(foo)
cfoo(x)
thunder_compiled_fns = cfoo._backend.subgraph_infos[0].thunder_compiled_fns
assert len(thunder_compiled_fns) == 1
assert last_traces(thunder_compiled_fns[0])

from thunder.dev_utils.nvtx_profile_transform import NvtxProfileTransform

cfoo = thunderfx(foo, dynamic=True, transforms=[NvtxProfileTransform()])
mruberry marked this conversation as resolved.
Show resolved Hide resolved
cfoo(x)
thunder_compiled_fns = cfoo._backend.subgraph_infos[0].thunder_compiled_fns
assert len(thunder_compiled_fns) == 1
trc = last_traces(thunder_compiled_fns[-1])[-1]
assert any(bsym.sym.id == "nvtx_range_push" for bsym in trc.bound_symbols)
Loading