Add graph-by-graph benchmarking of dynamo.ThunderCompiler #1066

Merged: 21 commits, Oct 15, 2024
Changes from 1 commit
Support the splitter
kiya00 committed Oct 11, 2024

Verified

This commit was signed with the committer’s verified signature.
snyk-bot Snyk bot
commit 8fdfd8da1290733db0d2612d717026cc2f03ca82
125 changes: 95 additions & 30 deletions thunder/benchmarks/__init__.py
@@ -709,11 +709,6 @@ def torch_compile_executor(fn: Callable) -> Callable:
return torch.compile(fn)


def torch_compile_without_reset_executor(fn: Callable) -> Callable:
torch.backends.cuda.matmul.allow_tf32 = True
return torch.compile(fn)


def thunder_torch_executor(fn: Callable) -> Callable:
torch.backends.cuda.matmul.allow_tf32 = True
return thunder.jit(fn, executors=[thunder.pytorch_executor])
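For readers unfamiliar with these executor factories: each one enables TF32 matmuls and wraps a callable with a compiler. A minimal usage sketch of the Thunder path, assuming `lightning-thunder` is installed (the `square` function is illustrative only):

```python
import torch
import thunder

def square(x: torch.Tensor) -> torch.Tensor:
    return x * x

# thunder.jit traces and compiles the function; restricting `executors`
# to the PyTorch executor dispatches every operation to eager PyTorch.
jitted = thunder.jit(square, executors=[thunder.pytorch_executor])
print(jitted(torch.ones(3)))  # tensor([1., 1., 1.])
```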
@@ -3019,13 +3014,47 @@ def fn(self) -> Callable:
# list_benchmarks(use_classname=False)
# sys.exit(0)

from thunder.dynamo.utils import _concrete_shape
from thunder.dynamo.compiler import ThunderCompiler
from itertools import chain


def _get_example_inputs_from_placeholder(node) -> tuple[torch.Tensor]:
check(node.op == "placeholder", lambda: "The node must be a placeholder node")
assert "example_value" in node.meta
example_value = node.meta["example_value"]

class DynamoBackendBenchmarking:
_executors = {
"torch": torch_executor,
"torch.compile": torch_compile_without_reset_executor,
"thunder": thunder_executor,
}
if isinstance(example_value, torch.Tensor):
return (
make_tensor(
_concrete_shape(example_value),
dtype=example_value.dtype,
device=example_value.device,
requires_grad=example_value.requires_grad,
),
)
elif isinstance(example_value, tuple):
return tuple(
make_tensor(
_concrete_shape(e_v),
dtype=e_v.dtype,
device=e_v.device,
requires_grad=e_v.requires_grad,
)
for e_v in example_value
)
else:
raise TypeError(
f"The 'example_value' in the placeholder node is expected to be either a Tensor or a Tuple of Tensors."
)
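The helper above materializes concrete random inputs from the FakeTensor that Dynamo records in each placeholder's `example_value` meta. A simplified standalone sketch of the same idea (the `tensor_from_meta` name is hypothetical, and symbolic shapes, which `_concrete_shape` resolves, are ignored here):

```python
import torch
from torch.testing import make_tensor

def tensor_from_meta(example_value: torch.Tensor) -> torch.Tensor:
    # Build a real random tensor matching the recorded shape, dtype,
    # device, and requires_grad of the (fake) example value.
    return make_tensor(
        tuple(example_value.shape),
        dtype=example_value.dtype,
        device=example_value.device,
        requires_grad=example_value.requires_grad,
    )

x = tensor_from_meta(torch.empty(2, 3, dtype=torch.float32))
print(x.shape, x.dtype)  # torch.Size([2, 3]) torch.float32
```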


class DynamoBackendBenchmarking(ThunderCompiler):
_executors = (
"eager",
"inductor",
"thunder",
)

def __init__(
self,
@@ -3039,8 +3068,10 @@ def __init__(

Keyword arguments:
bench: the BenchmarkFixture created by `pytest_benchmark`
executors: List of executors to use. Supported executors include: 'torch', 'torch.compile', and 'thunder'. If None, defaults to all available executors.
thunder_options: a dictionary of options to pass to `thunder.jit`.
executors: List of executors to use. Supported executors include: 'eager', 'inductor', and 'thunder'. If None, defaults to all available executors.
thunder_options: a dictionary of options to pass to `thunder.jit`. In addition to all the arguments
accepted by `thunder.jit`, it accepts `torch_inductor_options`, which are passed to `torch.compile`
when part of the graph is not supported by Thunder.

Example:
>>> import torch
@@ -3055,45 +3086,79 @@ def __init__(
... return x - 1
...
>>> def test_func(benchmark):
... backend = DynamoBackendBenchmarking(benchmark)
... backend = DynamoBackendBenchmarking(benchmark, executors=["eager"])
... compiled = torch.compile(backend=backend)(func)
... x = torch.ones(2, requires_grad=True).cuda()
... compiled(x)
Running the example with `pytest script.py --benchmark-sort="name"` will produce benchmarking results.

# Dynamo segments the graph into two subgraphs, each identified by the 'GraphId[id]' field in the test name.
# Each subgraph contains a single split module, processed by the Thunder-defined splitter,
# which corresponds to the 'SplitModuleName[split_module_name]' field.
# The currently active executor is indicated by the '_eager' suffix.
----------------------------------------------------------------------------------------------------- benchmark: 2 tests -----------------------------------------------------------------------------------------------------
Name (time in us) Min Max Mean StdDev Median IQR Outliers OPS (Kops/s) Rounds Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_func_GraphId[1]_SplitModuleName[thunder_1]_eager 11.8791 (2.23) 14.2842 (1.95) 12.2122 (2.23) 0.3507 (1.05) 12.1247 (2.25) 0.1275 (2.36) 46;71 81.8853 (0.45) 841 100
test_func_GraphId[2]_SplitModuleName[thunder_1]_eager 5.3239 (1.0) 7.3174 (1.0) 5.4771 (1.0) 0.3326 (1.0) 5.3976 (1.0) 0.0540 (1.0) 94;157 182.5784 (1.0) 1879 100
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
"""
self.thunder_options = thunder_options
super().__init__(**thunder_options)
self.bench = bench
if not executors:
self.executors = DynamoBackendBenchmarking._executors.values()
self.executor_ids = list(DynamoBackendBenchmarking._executors.keys())
self.executors = DynamoBackendBenchmarking._executors
else:
check(
all(ex in DynamoBackendBenchmarking._executors for ex in executors),
lambda: f"DynamoBackendBenchmarking only supports the following executor names: {list(DynamoBackendBenchmarking._executors.keys())} ",
lambda: f"DynamoBackendBenchmarking only supports the following executor names: {DynamoBackendBenchmarking._executors} ",
)
self.executors = (DynamoBackendBenchmarking._executors[ex] for ex in executors)
self.executor_ids = executors
self.gm2stats = {}
self.executors = executors
self.graph_idx = 0

def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, torch.Tensor]):
def get_bench_result(self, gm: torch.fx.GraphModule, name: str, sample_args: list[torch.Tensor]):
from thunder.benchmarks.targets import record_peak_allocated_memory

gm.real_recompile()

for idx, ex in enumerate(self.executors):
fn = ex(gm, **self.thunder_options) if self.executor_ids[idx] == "thunder" else ex(gm)
for ex in self.executors:
# Reuses the already-compiled module when it was compiled with the expected executor
if name.startswith(ex):
fn = self.subgraph_infos[self.graph_idx].submodule_to_compiled_functions[gm].compiled_fn
else:
if ex == "thunder":
fn = self._thunder_jit(gm)
elif ex == "inductor":
fn = self._torch_compile(gm)
else:
fn = gm
with record_peak_allocated_memory(self.bench):
self.bench(fn, *sample_args)
# BenchmarkFixture.stats is created each time bench is called (ref: https://github.com/pybenchmark/pytest-benchmark/blob/8c9a5faa1dd178b53ab7b2a66f5364a77e903d74/src/pytest_benchmark/fixture.py#L150)
# Adds the graph number to the name string
self.bench.stats.name = self.bench.stats.name + f"_{self.executor_ids[idx]}.graph{self.graph_idx}"
self.gm2stats[gm] = self.bench.stats
# Adds the graph number, split module name and executor suffix to the name string
self.bench.stats.name = (
self.bench.stats.name + f"_GraphId[{self.graph_idx+1}]_SplitModuleName[{name}]" + f"_{ex}"
)

# When the graph is segmented, self.bench runs multiple times and pytest-benchmark raises:
# `FixtureAlreadyUsed("Fixture can only be used once. Previously it was used in %s mode." % self._mode)`
# Ref: https://github.com/pybenchmark/pytest-benchmark/blob/8c9a5faa1dd178b53ab7b2a66f5364a77e903d74/src/pytest_benchmark/fixture.py#L115-L118
# Manually reset BenchmarkFixture._mode to None here to avoid this
self.bench._mode = None
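The `_mode = None` reset is the trick that lets one `benchmark` fixture time several split modules in a single test. A minimal sketch of the pattern in isolation; it relies on the private `BenchmarkFixture._mode` attribute referenced in the comment above, so it may break with future pytest-benchmark versions:

```python
def test_two_runs(benchmark):
    benchmark(lambda: sum(range(100)))
    # Without this reset the second call raises FixtureAlreadyUsed.
    benchmark._mode = None
    benchmark(lambda: sum(range(1_000)))
```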

def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, torch.Tensor]):
split_module = super().__call__(gm, sample_args)
compiled_functions_to_submodule = {
v.compiled_fn: k for k, v in self.subgraph_infos[self.graph_idx].submodule_to_compiled_functions.items()
}
for node in split_module.graph.nodes:
target = node.target
# Benchmarks the modules produced by the splitter
if isinstance(target, str) and target.startswith(("thunder_", "inductor_")):
assert hasattr(split_module, target)
cur_module = getattr(split_module, target)
cur_nodes = cur_module.graph.nodes
# Generates random inputs for the current module based on the FakeTensor in the placeholder's 'example_value'
placeholders = list(n for n in cur_nodes if n.op == "placeholder")
args = list(chain(*map(_get_example_inputs_from_placeholder, placeholders)))

self.get_bench_result(compiled_functions_to_submodule[cur_module], target, args)
self.graph_idx += 1
return gm
return split_module
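For intuition about the `target.startswith(("thunder_", "inductor_"))` filter above: the Thunder splitter installs each partition on the split module under a `thunder_<n>` or `inductor_<n>` attribute, referenced by `call_module` nodes in the top-level graph. A hedged sketch of inspecting such a module (`split_module` is assumed to come from the splitter; `list_partitions` is a hypothetical helper):

```python
import torch

def list_partitions(split_module: torch.fx.GraphModule) -> None:
    for node in split_module.graph.nodes:
        # call_module nodes carry the submodule's attribute name as their target.
        if node.op == "call_module" and str(node.target).startswith(("thunder_", "inductor_")):
            submodule = getattr(split_module, node.target)
            n_inputs = sum(1 for n in submodule.graph.nodes if n.op == "placeholder")
            print(f"{node.target}: {n_inputs} placeholder input(s)")
```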
13 changes: 2 additions & 11 deletions thunder/benchmarks/targets.py
@@ -35,7 +35,6 @@
thunder_executor,
thunder_sdpa_torch_compile_nvfuser_executor,
torch_compile_executor,
torch_compile_without_reset_executor,
torch_executor,
thunder_transformerengine_executor,
DynamoBackendBenchmarking,
@@ -887,16 +886,8 @@ def test_torchbench_canary(benchmark, module_name, executor, compute_type: Compu
benchmark_for_compute_type(compute_type, benchmark, fn, args, kwargs)


@pytest.mark.parametrize(
"executor,",
(
"torch",
"torch.compile",
"thunder",
),
)
def test_dynamo_LlamaMLPBenchmark(benchmark, executor: Callable):
backend = DynamoBackendBenchmarking(benchmark, [executor])
def test_dynamo_LlamaMLPBenchmark(benchmark):
backend = DynamoBackendBenchmarking(benchmark)

bench: Benchmark = LlamaMLPBenchmark(
config="Llama-2-7b-hf",