Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add graph-by-graph benchmarking of dynamo.ThunderCompiler #1066

Merged
merged 21 commits into from
Oct 15, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add comments
kiya00 committed Oct 11, 2024
commit 048451f4f8b68a72f4da2484a05a5d4ce60814e0
36 changes: 20 additions & 16 deletions thunder/benchmarks/__init__.py
Original file line number Diff line number Diff line change
@@ -3020,8 +3020,8 @@ def fn(self) -> Callable:


def _get_example_inputs_from_placeholder(node) -> tuple[torch.Tensor]:
check(node.op == "placeholder", lambda: f"The node must be placeholder type")
assert "example_value" in node.meta
check(node.op == "placeholder", lambda: f"The node must be placeholder type", ValueError)
check("example_value" in node.meta, lambda: "example_value does not exist in the meta of {node}", ValueError)
example_value = node.meta["example_value"]

if isinstance(example_value, torch.Tensor):
@@ -3049,7 +3049,7 @@ def _get_example_inputs_from_placeholder(node) -> tuple[torch.Tensor]:
)


class DynamoBackendBenchmarking(ThunderCompiler):
class ThunderCompilerGraphBenchmarking(ThunderCompiler):
_executors = (
"eager",
"inductor",
@@ -3063,7 +3063,7 @@ def __init__(
**thunder_options,
):
"""
This class acts as a backend for the `torch.compile` function, facilitating the benchmarking of each `fx.GraphModule`.
This class acts as a backend for the `torch.compile` function, facilitating the benchmarking of each `fx.GraphModule` produced by Thunder dynamo splitter.
Each `fx.GraphModule` instance is executed by the specified executors and benchmarked using `pytest_benchmark`.

Keyword arguments:
@@ -3076,7 +3076,7 @@ def __init__(
Example:
>>> import torch
>>> import thunder
>>> from thunder.benchmarks import DynamoBackendBenchmarking
>>> from thunder.benchmarks import ThunderCompilerGraphBenchmarking
>>>
>>> def func(x):
... x = torch.sin(x)
@@ -3086,7 +3086,7 @@ def __init__(
... return x - 1
...
>>> def test_func(benchmark):
... backend = DynamoBackendBenchmarking(benchmark, executors=["eager"])
... backend = ThunderCompilerGraphBenchmarking(benchmark, executors=["eager"])
... compiled = torch.compile(backend=backend)(func)
... x = torch.ones(2, requires_grad=True).cuda()
... compiled(x)
@@ -3106,16 +3106,16 @@ def __init__(
super().__init__(**thunder_options)
self.bench = bench
if not executors:
self.executors = DynamoBackendBenchmarking._executors
self.executors = ThunderCompilerGraphBenchmarking._executors
else:
check(
all(ex in DynamoBackendBenchmarking._executors for ex in executors),
lambda: f"DynamoBackendBenchmarking only supports the following executor names: {DynamoBackendBenchmarking._executors} ",
all(ex in ThunderCompilerGraphBenchmarking._executors for ex in executors),
lambda: f"ThunderCompilerGraphBenchmarking only supports the following executor names: {ThunderCompilerGraphBenchmarking._executors} ",
)
self.executors = executors
self.graph_idx = 0

def get_bench_result(self, gm: torch.fx.GraphModule, name: str, sample_args: list[torch.Tensor]):
def run_bench(self, gm: torch.fx.GraphModule, name: str, *sample_args) -> None:
from thunder.benchmarks.targets import record_peak_allocated_memory

for ex in self.executors:
@@ -3150,15 +3150,19 @@ def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, tor
}
for node in split_module.graph.nodes:
target = node.target
# Benchmarks the modules produced by splitter
# Benchmarks the modules produced by the splitter.
if isinstance(target, str) and target.startswith(("thunder_", "inductor_")):
assert hasattr(split_module, target)
check(
hasattr(split_module, target),
lambda: f"the submodule {target} does not exist in {split_module}",
ValueError,
)
cur_module = getattr(split_module, target)
cur_nodes = cur_module.graph.nodes
# Gets a random generated inputs to the current module based on the faketensor in the 'example_value' of the placeholder
# Greates random input values for the current module based on the faketensor 'example_value' of the placeholder node
placeholders = list(n for n in cur_nodes if n.op == "placeholder")
args = list(chain(*map(_get_example_inputs_from_placeholder, placeholders)))

self.get_bench_result(compiled_functions_to_submodule[cur_module], target, args)
args = chain(*map(_get_example_inputs_from_placeholder, placeholders))
# Runs the benchmark on the original module with the generated random inputs
self.run_bench(compiled_functions_to_submodule[cur_module], target, *args)
self.graph_idx += 1
return split_module
8 changes: 6 additions & 2 deletions thunder/benchmarks/targets.py
Original file line number Diff line number Diff line change
@@ -37,7 +37,7 @@
torch_compile_executor,
torch_executor,
thunder_transformerengine_executor,
DynamoBackendBenchmarking,
ThunderCompilerGraphBenchmarking,
)
from thunder.core.interpreter import interpret

@@ -886,8 +886,12 @@ def test_torchbench_canary(benchmark, module_name, executor, compute_type: Compu
benchmark_for_compute_type(compute_type, benchmark, fn, args, kwargs)


#
# Thunder dynamo graph-by-graph benchmarks
#

def test_dynamo_LlamaMLPBenchmark(benchmark):
backend = DynamoBackendBenchmarking(benchmark)
backend = ThunderCompilerGraphBenchmarking(benchmark)

bench: Benchmark = LlamaMLPBenchmark(
config="Llama-2-7b-hf",