Add ThunderFX as a benchmark executor #1249


Merged: 6 commits, Oct 15, 2024
8 changes: 8 additions & 0 deletions thunder/benchmarks/__init__.py
@@ -18,6 +18,7 @@
 from torch.testing import make_tensor

 import thunder
+import thunder.dynamo
 import thunder.core.devices as Devices
 import thunder.core.dtypes as dtypes
 import thunder.executors as executors
@@ -707,6 +708,13 @@ def torch_compile_executor(fn: Callable) -> Callable:
     return torch.compile(fn)


+def thunderfx_executor(fn: Callable) -> Callable:
+    torch.backends.cuda.matmul.allow_tf32 = True
+    backend = thunder.dynamo.ThunderCompiler()
+    torch._dynamo.reset()
+    return torch.compile(fn, backend=backend)
+
+
 def thunder_torch_executor(fn: Callable) -> Callable:
     torch.backends.cuda.matmul.allow_tf32 = True
     return thunder.jit(fn, executors=[thunder.pytorch_executor])
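The new thunderfx_executor follows the same wrapper convention as the other executors in this file: take a callable, set any required global state, and hand back a compiled version. A minimal pure-Python sketch of that convention, with no torch or thunder dependency (`identity_compile` and `make_executor` are hypothetical stand-ins, not names from the repository):

```python
from typing import Callable


# Hypothetical placeholder for a compiler entry point such as torch.compile.
def identity_compile(fn: Callable) -> Callable:
    return fn


def make_executor(compiler: Callable) -> Callable:
    # Mirrors the executor signature used in thunder/benchmarks/__init__.py:
    # an executor takes a callable and returns a (possibly compiled) callable.
    def executor(fn: Callable) -> Callable:
        # A real executor (e.g. thunderfx_executor) would also configure
        # global state here, such as torch.backends.cuda.matmul.allow_tf32.
        return compiler(fn)

    return executor


plain_executor = make_executor(identity_compile)


def square(x: int) -> int:
    return x * x


compiled_square = plain_executor(square)
print(compiled_square(4))  # → 16
```

Keeping every executor to this one-argument shape is what lets the benchmark harness parametrize over them interchangeably.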
21 changes: 12 additions & 9 deletions thunder/benchmarks/targets.py
@@ -33,6 +33,7 @@
     thunder_cudnn_executor,
     thunder_cudnn_nvfuser_executor,
     thunder_executor,
+    thunderfx_executor,
     thunder_sdpa_torch_compile_nvfuser_executor,
     torch_compile_executor,
     torch_executor,
@@ -151,17 +152,19 @@ def interpreter_fwd(module: Callable):
     return fn_


-executors = (
-    torch_executor,
-    torch_compile_executor,
-    thunder_executor,
-)
+executors = (torch_executor, torch_compile_executor, thunder_executor)
Comment on lines -154 to +155

Collaborator: What formatter applied this change?

Collaborator (Author): @IvanYashchuk: ruff, configured by the [tool.ruff] section. Let me know if in your case this doesn't match expectations.

Collaborator: I was just curious about this one-line change; usually a pre-commit job applies changes only to the modified part of the code (at least the one run in CI).
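The [tool.ruff] section referred to above lives in the project's pyproject.toml. A hypothetical fragment showing the kind of setting that makes ruff collapse a short multi-line tuple onto one line (the repository's actual configuration may differ):

```toml
# pyproject.toml (illustrative only; not necessarily Thunder's real settings)
[tool.ruff]
line-length = 120

[tool.ruff.format]
# When enabled, the formatter ignores "magic" trailing commas and collapses
# constructs that fit within line-length onto a single line.
skip-magic-trailing-comma = true
```

Without skip-magic-trailing-comma, ruff's formatter would have respected the trailing comma in the original executors tuple and kept it multi-line, which is consistent with the reflow seen in this diff.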

 executors_ids = (
     "torch",
     "torch.compile",
     "thunder",
 )

+torchbench_executors = (*executors, thunderfx_executor)
+torchbench_executors_ids = (
+    *executors_ids,
+    "thunderfx",
+)

 apex_executors = (thunder_apex_executor, thunder_apex_nvfuser_executor)
 apex_executors_ids = ("thunder+apex-grad", "thunder+apex+nvfuser-grad")
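The torchbench-only tuples extend the shared ones by unpacking, which keeps each executor positionally aligned with its id while leaving the base tuples untouched for the other benchmarks. A small self-contained illustration of that pattern (the executor functions here are trivial stand-ins, not the real ones):

```python
# Hypothetical stand-in executors; only the tuple-pairing pattern is real.
def torch_executor(fn):
    return fn


def torch_compile_executor(fn):
    return fn


def thunder_executor(fn):
    return fn


def thunderfx_executor(fn):
    return fn


executors = (torch_executor, torch_compile_executor, thunder_executor)
executors_ids = ("torch", "torch.compile", "thunder")

# Unpacking builds extended tuples without mutating the shared ones, so the
# torchbench tests gain thunderfx while other tests keep the base set.
torchbench_executors = (*executors, thunderfx_executor)
torchbench_executors_ids = (*executors_ids, "thunderfx")

print(torchbench_executors_ids)  # → ('torch', 'torch.compile', 'thunder', 'thunderfx')
```

pytest.mark.parametrize then zips argvalues with ids by position, so the two tuples must stay the same length.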

@@ -841,8 +844,8 @@ def test_resnet50(benchmark, executor: Callable, compute_type: ComputeType):
 )
 @pytest.mark.parametrize(
     "executor,",
-    executors,
-    ids=executors_ids,
+    torchbench_executors,
+    ids=torchbench_executors_ids,
 )
 @parametrize_compute_type
 def test_torchbench(benchmark, module_name, executor, compute_type: ComputeType):
@@ -867,8 +870,8 @@ def test_torchbench(benchmark, module_name, executor, compute_type: ComputeType)
 )
 @pytest.mark.parametrize(
     "executor,",
-    executors,
-    ids=executors_ids,
+    torchbench_executors,
+    ids=torchbench_executors_ids,
 )
 @parametrize_compute_type
 def test_torchbench_canary(benchmark, module_name, executor, compute_type: ComputeType):