move the thunderfx related functions into dynamo
kiya00 committed Oct 9, 2024
1 parent 5bf5f69 commit b3b3076
Showing 7 changed files with 239 additions and 252 deletions.
184 changes: 0 additions & 184 deletions thunder/benchmarks/__init__.py
@@ -10,7 +10,6 @@
from functools import partial
from numbers import Number
from typing import Any
from pytest_benchmark.fixture import BenchmarkFixture

import torch
import torch.multiprocessing as mp
@@ -24,7 +23,6 @@
import thunder.executors as executors
import thunder.torch as ltorch
from thunder.core.transforms import grad, clear_grads, populate_grads
from thunder.core.utils import check
from thunder.executors.apexex import apex_ex, apex_entropy_available
from thunder.executors.cudnn_layernormex import cudnn_layernorm_ex
from thunder.executors.cudnnex import cudnn_ex, cudnn_available
@@ -3013,185 +3011,3 @@ def fn(self) -> Callable:
# if args.listbenchmarks:
# list_benchmarks(use_classname=False)
# sys.exit(0)

from thunder.dynamo.utils import _concrete_shape
from thunder.dynamo.compiler import ThunderCompiler
from itertools import chain

GRAPH_BY_GRAPH_BENCHMARK_PARAMS_KEYS = ("GraphID", "SplitModuleName", "executor")


def _get_example_inputs_from_placeholder(node) -> tuple[torch.Tensor]:
    check(node.op == "placeholder", lambda: "The node must be a placeholder node", ValueError)
    check("example_value" in node.meta, lambda: f"example_value does not exist in the meta of {node}", ValueError)
    example_value = node.meta["example_value"]

    if isinstance(example_value, torch.Tensor):
        return (
            make_tensor(
                _concrete_shape(example_value),
                dtype=example_value.dtype,
                device=example_value.device,
                requires_grad=example_value.requires_grad,
            ),
        )
    elif isinstance(example_value, tuple):
        return tuple(
            make_tensor(
                _concrete_shape(e_v),
                dtype=e_v.dtype,
                device=e_v.device,
                requires_grad=e_v.requires_grad,
            )
            for e_v in example_value
        )
    else:
        raise TypeError(
            "The 'example_value' in the placeholder node is expected to be either a Tensor or a Tuple of Tensors."
        )


class ThunderCompilerGraphBenchmarking(ThunderCompiler):
    _executors = (
        "eager",
        "inductor",
        "thunder",
    )

    def __init__(
        self,
        bench: BenchmarkFixture,
        executors: Sequence[str],
        **thunder_options,
    ):
        """
        This class acts as a backend for `torch.compile`, facilitating the benchmarking of each `fx.GraphModule` produced by the Thunder dynamo splitter.
        Each `fx.GraphModule` instance is executed by the specified executors and benchmarked using `pytest-benchmark`.

        Keyword arguments:
            bench: the BenchmarkFixture created by `pytest-benchmark`
            executors: list of executors to compare. Supported executors are 'eager', 'inductor', and 'thunder'. If None, defaults to all available executors.
            thunder_options: a dictionary of options to pass to `thunder.jit`. Besides all the arguments to `thunder.jit`,
                it accepts `torch_inductor_options`, which are passed to `torch.compile` if part of the graph
                is not supported by Thunder.

        Example:
            ```
            # script.py
            import torch
            import thunder
            from thunder.benchmarks import ThunderCompilerGraphBenchmarking

            def func(x):
                x = torch.sin(x)
                if x.sum() > 0:
                    return x + 1
                else:
                    return x - 1

            def test_func(benchmark):
                backend = ThunderCompilerGraphBenchmarking(benchmark, executors=["eager", "thunder"])
                compiled = torch.compile(backend=backend)(func)
                x = torch.ones(2, requires_grad=True).cuda()
                compiled(x)
            ```

        Note: Ensure the pytest configuration file (`thunder/benchmarks/conftest.py`) is present in the same directory as `script.py` to provide the grouping customization.
            .
            ├── script.py
            ├── conftest.py

        Usage:
            To run the benchmark test and group the results by split module, execute the following command:
            `pytest script.py --benchmark-group-by='graph-by-graph:param:GraphID,param:SplitModuleName'`

            Dynamo segments the example graph into two subgraphs, each identified by the 'GraphID[id]' field in the test name.
            Each subgraph contains a single split module, produced by the Thunder-defined splitter,
            which corresponds to the 'SplitModuleName[split_module_name]' field.
            The currently active executor is indicated by the 'executor[executor_name]' field.
            With `--benchmark-group-by='graph-by-graph:param:GraphID,param:SplitModuleName'`, the test cases are grouped by GraphID and SplitModuleName,
            allowing performance comparison between different executors (e.g., 'eager' vs. 'thunder').
--------------------------------------------------------------------------- benchmark 'GraphID=GraphID[1] SplitModuleName=SplitModuleName[thunder_1]': 2 tests ---------------------------------------------------------------------------
Name (time in us) Min Max Mean StdDev Median IQR Outliers OPS (Kops/s) Rounds Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_func-GraphID[1]-SplitModuleName[thunder_1]-executor[eager] 12.6325 (1.0) 14.6452 (1.0) 12.8461 (1.0) 0.3345 (1.0) 12.7634 (1.0) 0.0794 (1.0) 44;56 77.8446 (1.0) 795 100
test_func-GraphID[1]-SplitModuleName[thunder_1]-executor[thunder] 67.3176 (5.33) 97.5824 (6.66) 70.5751 (5.49) 4.5239 (13.53) 69.3277 (5.43) 1.3885 (17.48) 114;125 14.1693 (0.18) 1501 10
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
--------------------------------------------------------------------------- benchmark 'GraphID=GraphID[2] SplitModuleName=SplitModuleName[thunder_1]': 2 tests ---------------------------------------------------------------------------
Name (time in us) Min Max Mean StdDev Median IQR Outliers OPS (Kops/s) Rounds Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_func-GraphID[2]-SplitModuleName[thunder_1]-executor[eager] 5.6229 (1.0) 7.6670 (1.0) 5.7683 (1.0) 0.3353 (1.0) 5.6884 (1.0) 0.0291 (1.0) 88;146 173.3627 (1.0) 1793 100
test_func-GraphID[2]-SplitModuleName[thunder_1]-executor[thunder] 63.2247 (11.24) 85.5654 (11.16) 66.3187 (11.50) 3.5975 (10.73) 65.4071 (11.50) 1.3760 (47.28) 97;117 15.0787 (0.09) 1584 10
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
"""
super().__init__(**thunder_options)
self.bench = bench
if not executors:
self.executors = ThunderCompilerGraphBenchmarking._executors
else:
check(
all(ex in ThunderCompilerGraphBenchmarking._executors for ex in executors),
lambda: f"ThunderCompilerGraphBenchmarking only supports the following executor names: {ThunderCompilerGraphBenchmarking._executors} ",
)
self.executors = executors
self.graph_idx = 0

    def run_bench(self, gm: torch.fx.GraphModule, name: str, *sample_args) -> None:
        from thunder.benchmarks.targets import record_peak_allocated_memory, MAX_ALLOCATED_MEMORY_KEYWORD

        for ex in self.executors:
            # Uses the already compiled module if it was compiled with the expected executor
            if name.startswith(ex):
                fn = self.subgraph_infos[self.graph_idx].submodule_to_compiled_functions[gm].compiled_fn
            else:
                if ex == "thunder":
                    fn = self._thunder_jit(gm)
                elif ex == "inductor":
                    fn = self._torch_compile(gm)
                else:
                    fn = gm
            with record_peak_allocated_memory(self.bench):
                self.bench(fn, *sample_args)
            # BenchmarkFixture.stats is created each time bench is called (ref: https://github.com/pybenchmark/pytest-benchmark/blob/8c9a5faa1dd178b53ab7b2a66f5364a77e903d74/src/pytest_benchmark/fixture.py#L150)
            # Adds the graph number, split module name, and executor suffix to the name string
            gid_key, module_name_key, ex_key = GRAPH_BY_GRAPH_BENCHMARK_PARAMS_KEYS
            self.bench.stats.name += f"-{gid_key}[{self.graph_idx+1}]-{module_name_key}[{name}]-{ex_key}[{ex}]"
            assert MAX_ALLOCATED_MEMORY_KEYWORD in self.bench.extra_info
            assert f"{self.bench.stats.name}_{MAX_ALLOCATED_MEMORY_KEYWORD}" not in self.bench.extra_info
            # NOTE: A benchmark can include multiple stats, but only one extra_info field is allowed per benchmark.
            # Therefore, the current stats name is used as a prefix to distinguish the memory usage of each stats record.
            self.bench.extra_info[f"{self.bench.stats.name}_{MAX_ALLOCATED_MEMORY_KEYWORD}"] = (
                self.bench.extra_info.pop(MAX_ALLOCATED_MEMORY_KEYWORD)
            )

            # When the graph is segmented, self.bench runs multiple times, and pytest-benchmark throws an error:
            # `FixtureAlreadyUsed("Fixture can only be used once. Previously it was used in %s mode." % self._mode)`
            # Ref: https://github.com/pybenchmark/pytest-benchmark/blob/8c9a5faa1dd178b53ab7b2a66f5364a77e903d74/src/pytest_benchmark/fixture.py#L115-L118
            # Here BenchmarkFixture._mode is manually reset to None to avoid it.
            self.bench._mode = None

    def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, torch.Tensor]):
        split_module = super().__call__(gm, sample_args)
        compiled_functions_to_submodule = {
            v.compiled_fn: k for k, v in self.subgraph_infos[self.graph_idx].submodule_to_compiled_functions.items()
        }
        for node in split_module.graph.nodes:
            target = node.target
            # Benchmarks the modules produced by the splitter.
            if isinstance(target, str) and target.startswith(("thunder_", "inductor_")):
                check(
                    hasattr(split_module, target),
                    lambda: f"the submodule {target} does not exist in {split_module}",
                    ValueError,
                )
                cur_module = getattr(split_module, target)
                cur_nodes = cur_module.graph.nodes
                # Creates random input values for the current module based on the FakeTensor 'example_value' of the placeholder nodes
                placeholders = list(n for n in cur_nodes if n.op == "placeholder")
                args = chain(*map(_get_example_inputs_from_placeholder, placeholders))
                # Runs the benchmark on the original module with the generated random inputs
                self.run_bench(compiled_functions_to_submodule[cur_module], target, *args)
        self.graph_idx += 1
        return split_module
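
For context on why `_get_example_inputs_from_placeholder` works: when Dynamo hands a backend an `fx.GraphModule`, each placeholder node typically carries a FakeTensor under `node.meta["example_value"]`, recording shape, dtype, device, and `requires_grad`. Below is a minimal, standalone sketch (not part of this commit) that turns that metadata into random inputs; it assumes every placeholder holds a plain Tensor-valued `example_value` and uses `torch.testing.make_tensor` as the factory, whereas the removed helper above also handles tuples and symbolic shapes via `_concrete_shape`.

```python
import torch
from torch.testing import make_tensor


def random_inputs_from_placeholders(gm: torch.fx.GraphModule) -> tuple[torch.Tensor, ...]:
    # Sketch: build one random tensor per placeholder from its FakeTensor metadata.
    inputs = []
    for node in gm.graph.nodes:
        if node.op != "placeholder":
            continue
        ev = node.meta["example_value"]
        inputs.append(
            make_tensor(
                tuple(int(s) for s in ev.shape),  # assumes static (non-SymInt) dimensions
                dtype=ev.dtype,
                device=ev.device,
                requires_grad=ev.requires_grad,
            )
        )
    return tuple(inputs)
```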
39 changes: 0 additions & 39 deletions thunder/benchmarks/conftest.py

This file was deleted.

26 changes: 0 additions & 26 deletions thunder/benchmarks/targets.py
@@ -37,7 +37,6 @@
    torch_compile_executor,
    torch_executor,
    thunder_transformerengine_executor,
    ThunderCompilerGraphBenchmarking,
)
from thunder.core.interpreter import interpret

@@ -885,28 +884,3 @@ def test_torchbench_canary(benchmark, module_name, executor, compute_type: Compu
    fn = executor(b.fn())

    benchmark_for_compute_type(compute_type, benchmark, fn, args, kwargs)


#
# Thunder dynamo graph-by-graph benchmarks
#


def test_dynamo_LlamaMLPBenchmark(benchmark):
    backend = ThunderCompilerGraphBenchmarking(benchmark, executors=["thunder", "inductor", "eager"])

    bench: Benchmark = LlamaMLPBenchmark(
        config="Llama-2-7b-hf",
        batchdims=(2,),
        device="cuda:0",
        dtype=thunder.bfloat16,
        requires_grad=True,
    )

    args, kwargs = bench.make_batch()
    # Using torch.compile here fails with "TypeError: cannot pickle '_io.TextIOWrapper' object" in
    # https://github.com/Lightning-AI/pytorch-lightning/blob/828fd998961f6a60f92c35254bb94d6e049ad069/src/lightning/fabric/wrappers.py#L421
    fn = torch._dynamo.optimize(backend=backend)(bench.fn())
    fn(*args, **kwargs)

    # Avoid hitting torch._dynamo's config.cache_size_limit (8)
    torch._dynamo.reset()
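
A side note on the final `torch._dynamo.reset()` above: it clears Dynamo's compile cache between benchmark runs so the default recompile limit (8) is never reached. A hedged alternative, sketched below under the assumption that the benchmarks do not depend on isolated caches, is to raise the limit up front instead:

```python
import torch

# Raise Dynamo's recompile cache limit so repeated graph captures across
# benchmark variants do not fall back to eager once the default limit (8) is hit.
torch._dynamo.config.cache_size_limit = 64  # assumed large enough for all variants
```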
3 changes: 2 additions & 1 deletion thunder/dynamo/__init__.py
@@ -1,6 +1,7 @@
from thunder.dynamo.compiler import ThunderCompiler
from thunder.dynamo.compiler import ThunderCompiler, ThunderCompilerGraphBenchmarking


__all__ = [
"ThunderCompiler",
"ThunderCompilerGraphBenchmarking",
]
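
With this commit the graph-by-graph benchmarking backend is re-exported from `thunder.dynamo`. A minimal usage sketch against the new import location (mirroring the docstring example above; the constructor options themselves are unchanged by the move):

```python
# script.py -- run with:
#   pytest script.py --benchmark-group-by='graph-by-graph:param:GraphID,param:SplitModuleName'
import torch
from thunder.dynamo import ThunderCompilerGraphBenchmarking


def func(x):
    x = torch.sin(x)
    if x.sum() > 0:
        return x + 1
    return x - 1


def test_func(benchmark):
    backend = ThunderCompilerGraphBenchmarking(benchmark, executors=["eager", "thunder"])
    compiled = torch.compile(func, backend=backend)
    compiled(torch.ones(2, device="cuda", requires_grad=True))
```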