Add graph-by-graph benchmarking of dynamo.ThunderCompiler #1066

Merged · 21 commits · Oct 15, 2024
3 changes: 2 additions & 1 deletion thunder/benchmarks/targets.py
@@ -56,6 +56,7 @@
"phi-2",
]
RUN_ALL_CONFIGS = os.environ.get("THUNDER_BENCH_RUN_ALL_CONFIGS", "0") == "1"
MAX_ALLOCATED_MEMORY_KEYWORD = "max_allocated_memory_MB"


class ComputeType(Enum):
@@ -113,7 +114,7 @@ def deco(old_timer):
@functools.wraps(old_timer)
def timer():
ret = old_timer()
benchmark.extra_info["max_allocated_memory_MB"] = torch.cuda.max_memory_allocated() / (1024 * 1024.0)
benchmark.extra_info[MAX_ALLOCATED_MEMORY_KEYWORD] = torch.cuda.max_memory_allocated() / (1024 * 1024.0)
torch.cuda.reset_peak_memory_stats()
return ret

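The `record_peak_allocated_memory` context manager that the new compiler_graph_benchmark module imports from thunder/benchmarks/targets.py is not part of this hunk; the following is a minimal sketch of how such a helper could look, assuming it mirrors the timer decorator above (the actual implementation in the repository may differ):

import contextlib

import torch

from thunder.benchmarks.targets import MAX_ALLOCATED_MEMORY_KEYWORD


@contextlib.contextmanager
def record_peak_allocated_memory(benchmark):
    # Reset the CUDA peak-memory counter, run the benchmarked call, then store
    # the peak allocation (in MB) in the fixture's extra_info, mirroring the
    # timer decorator patched above.
    torch.cuda.reset_peak_memory_stats()
    try:
        yield
    finally:
        benchmark.extra_info[MAX_ALLOCATED_MEMORY_KEYWORD] = torch.cuda.max_memory_allocated() / (1024 * 1024.0)
        torch.cuda.reset_peak_memory_stats()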
146 changes: 146 additions & 0 deletions thunder/dynamo/compiler_graph_benchmark.py
@@ -0,0 +1,146 @@
from __future__ import annotations
from itertools import chain
from pytest_benchmark.fixture import BenchmarkFixture
from typing import TYPE_CHECKING

import torch
from thunder.dynamo import ThunderCompiler
from thunder.dynamo.utils import _get_example_inputs_from_placeholder
from thunder.core.utils import check


if TYPE_CHECKING:
from collections.abc import Sequence

GRAPH_BY_GRAPH_BENCHMARK_PARAMS_KEYS = ("GraphID", "SplitModuleName", "executor")


class ThunderCompilerGraphBenchmarking(ThunderCompiler):
_executors = (
"eager",
"inductor",
"thunder",
)

def __init__(
self,
bench: BenchmarkFixture,
executors: Sequence[str],
**thunder_options,
):
"""
This class acts as a backend for the :func:`torch.compile` function, facilitating the benchmarking of each :class:`torch.fx.GraphModule` produced by the Thunder Dynamo splitter.
Each :class:`torch.fx.GraphModule` instance is executed by the specified executors and benchmarked using `pytest-benchmark`.

Args:
bench: the ``BenchmarkFixture`` created by ``pytest-benchmark``
executors: list of executors to compare. Supported executors include: 'eager', 'inductor', and 'thunder'. If None, defaults to all available executors.
**thunder_options: a dictionary of options to pass to :func:`thunder.jit`. Besides all the arguments to :func:`thunder.jit`,
it accepts `torch_inductor_options` which are passed to :func:`torch.compile` if part of the graph
is not supported by thunder.

Example:
.. code-block:: python

# script.py
import torch
from thunder.dynamo.compiler_graph_benchmark import ThunderCompilerGraphBenchmarking

def func(x):
x = torch.sin(x)
if x.sum() > 0:
return x + 1
else:
return x - 1

def test_func(benchmark):
backend = ThunderCompilerGraphBenchmarking(benchmark, executors=["eager", "thunder"])
compiled = torch.compile(backend=backend)(func)
x = torch.ones(2, requires_grad=True).cuda()
compiled(x)

Note:
Ensure the pytest configuration file (`thunder/tests/conftest.py`) is present in the same directory as `script.py` to provide the grouping customization.

To run the benchmark test and group the results by split module, execute the following command:
`pytest script.py --benchmark-group-by='graph-by-graph:param:GraphID,param:SplitModuleName'`

In this example, Dynamo segments the graph into two subgraphs, each identified by the 'GraphID[id]' field in the test name.
Each subgraph contains a single split module, processed by the Thunder-defined splitter,
which corresponds to the 'SplitModuleName[split_module_name]' field.
The currently active executor is indicated by the 'executor[executor_name]' field.
With `--benchmark-group-by='graph-by-graph:param:GraphID,param:SplitModuleName'`, the test cases are grouped based on GraphID and SplitModuleName,
allowing for performance comparison between different executors (e.g., 'eager' vs. 'thunder').
"""
super().__init__(**thunder_options)
self.bench = bench
if not executors:
self.executors = ThunderCompilerGraphBenchmarking._executors
else:
check(
all(ex in ThunderCompilerGraphBenchmarking._executors for ex in executors),
lambda: f"ThunderCompilerGraphBenchmarking only supports the following executor names: {ThunderCompilerGraphBenchmarking._executors} ",
)
self.executors = executors
self.graph_idx = 0

def run_bench(self, gm: torch.fx.GraphModule, name: str, *sample_args):
from thunder.benchmarks.targets import record_peak_allocated_memory, MAX_ALLOCATED_MEMORY_KEYWORD

for ex in self.executors:
# Uses the already compiled module if it is compiled with the expected executor
if name.startswith(ex):
fn = self.subgraph_infos[self.graph_idx].submodule_to_compiled_functions[gm].compiled_fn
else:
if ex == "thunder":
# The subgraph whose name starts with "inductor" is not supported by the Thunder backend.
if name.startswith("inductor"):
continue
fn = self._thunder_jit(gm)
elif ex == "inductor":
fn = self._torch_compile(gm)
else:
fn = gm
with record_peak_allocated_memory(self.bench):
self.bench(fn, *sample_args)
# BenchmarkFixture.stats is created each time bench is called (ref: https://github.com/pybenchmark/pytest-benchmark/blob/8c9a5faa1dd178b53ab7b2a66f5364a77e903d74/src/pytest_benchmark/fixture.py#L150)
# Adds the graph number, split module name and executor suffix to the name string
gid_key, module_name_key, ex_key = GRAPH_BY_GRAPH_BENCHMARK_PARAMS_KEYS
self.bench.stats.name += f"-{gid_key}[{self.graph_idx+1}]-{module_name_key}[{name}]-{ex_key}[{ex}]"
assert MAX_ALLOCATED_MEMORY_KEYWORD in self.bench.extra_info
assert f"{self.bench.stats.name}_{MAX_ALLOCATED_MEMORY_KEYWORD}" not in self.bench.extra_info
# NOTE: A benchmark can include multiple stats, but only one extra_info field is allowed per benchmark.
# Therefore, we use the current stats name as a prefix to distinguish memory usage for each stats.
self.bench.extra_info[f"{self.bench.stats.name}_{MAX_ALLOCATED_MEMORY_KEYWORD}"] = (
self.bench.extra_info.pop(MAX_ALLOCATED_MEMORY_KEYWORD)
)

# When the graph is segmented, self.bench runs multiple times and pytest-benchmark throws an error:
# `FixtureAlreadyUsed("Fixture can only be used once. Previously it was used in %s mode." % self._mode)`
# Ref: https://github.com/pybenchmark/pytest-benchmark/blob/8c9a5faa1dd178b53ab7b2a66f5364a77e903d74/src/pytest_benchmark/fixture.py#L115-L118
# Here we manually set BenchmarkFixture._mode = None to avoid it
self.bench._mode = None

def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, torch.Tensor]):
split_module = super().__call__(gm, sample_args)
compiled_functions_to_submodule = {
v.compiled_fn: k for k, v in self.subgraph_infos[self.graph_idx].submodule_to_compiled_functions.items()
}
for node in split_module.graph.nodes:
target = node.target
# Benchmarks the modules produced by the splitter.
if isinstance(target, str) and target.startswith(("thunder_", "inductor_")):
check(
hasattr(split_module, target),
lambda: f"the submodule {target} does not exist in {split_module}",
ValueError,
)
cur_module = getattr(split_module, target)
cur_nodes = cur_module.graph.nodes
# Creates random input values for the current module based on the FakeTensor 'example_value' of the placeholder node
placeholders = list(n for n in cur_nodes if n.op == "placeholder")
args = chain(*map(_get_example_inputs_from_placeholder, placeholders))
# Runs the benchmark on the original module with the generated random inputs
self.run_bench(compiled_functions_to_submodule[cur_module], target, *args)
self.graph_idx += 1
return split_module
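Because only one extra_info dict exists per benchmark, each per-graph memory reading ends up under a prefixed key. A minimal post-processing sketch, assuming results were saved with `pytest ... --benchmark-json=out.json` (the `benchmarks` and `extra_info` fields follow pytest-benchmark's JSON layout; the file name is illustrative):

import json

with open("out.json") as f:
    report = json.load(f)

for bench in report["benchmarks"]:
    # Each per-graph reading was renamed to "<stats name>_max_allocated_memory_MB" in run_bench.
    for key, value in bench.get("extra_info", {}).items():
        if key.endswith("max_allocated_memory_MB"):
            print(f"{key}: {value:.1f} MB")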
41 changes: 41 additions & 0 deletions thunder/dynamo/utils.py
@@ -5,12 +5,14 @@
import dataclasses
import inspect
import itertools
import warnings

import torch

from thunder.torch.default_torch_ops import torch_auto_registered_ops
from thunder.torch import _torch_to_thunder_function_map
from thunder.torch.langctx import torchctx
from thunder.core.utils import check

if TYPE_CHECKING:
from thunder.core.symbol import Symbol
@@ -377,3 +379,42 @@ def recompile_graph(gm: torch.fx.GraphModule):
if isinstance(gm, torch.fx._lazy_graph_module._LazyGraphModule):
return gm.real_recompile()
return gm.recompile()


def _get_example_inputs_from_placeholder(node) -> tuple[torch.Tensor, ...]:
from thunder.tests.make_tensor import make_tensor

check(node.op == "placeholder", lambda: f"The node must be placeholder type", ValueError)
# Prefers to use actual example value in GraphArg if available
if "grapharg" in node.meta:
example_value = node.meta["grapharg"].example
if isinstance(example_value, torch.Tensor):
return (example_value.detach().clone().requires_grad_(example_value.requires_grad),)

check("example_value" in node.meta, lambda: "example_value does not exist in the meta of {node}", ValueError)
example_value = node.meta["example_value"]

if isinstance(example_value, torch.Tensor):
sz = _concrete_shape(example_value)
return (
make_tensor(
sz,
dtype=example_value.dtype,
device=example_value.device,
requires_grad=example_value.requires_grad,
).as_strided(sz, example_value.stride()),
)
elif isinstance(example_value, tuple):
return tuple(
make_tensor(
_concrete_shape(e_v),
dtype=e_v.dtype,
device=e_v.device,
requires_grad=e_v.requires_grad,
).as_strided(_concrete_shape(e_v), e_v.stride())
for e_v in example_value
)
else:
raise TypeError(
"The 'example_value' in the placeholder node is expected to be either a Tensor or a Tuple of Tensors."
)
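For reference, the helper above is consumed one placeholder node at a time; a minimal usage sketch mirroring the call site in compiler_graph_benchmark.py, assuming `gm` is a torch.fx.GraphModule whose placeholder nodes carry Dynamo's 'example_value' metadata:

from itertools import chain

import torch

from thunder.dynamo.utils import _get_example_inputs_from_placeholder


def make_example_args(gm: torch.fx.GraphModule) -> tuple[torch.Tensor, ...]:
    # Collects the placeholder nodes and flattens the per-placeholder tuples of
    # freshly generated example tensors into a single argument tuple.
    placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]
    return tuple(chain(*map(_get_example_inputs_from_placeholder, placeholders)))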
39 changes: 39 additions & 0 deletions thunder/tests/conftest.py
@@ -0,0 +1,39 @@
import pytest
from collections import defaultdict
import pytest_benchmark
from thunder.dynamo.compiler_graph_benchmark import GRAPH_BY_GRAPH_BENCHMARK_PARAMS_KEYS


@pytest.hookimpl(hookwrapper=True)
def pytest_benchmark_group_stats(config, benchmarks, group_by):
"""
This function customizes the grouping behavior for ThunderCompilerGraphBenchmarking.
The custom grouping function is only invoked when the `--benchmark-group-by`
option is set to 'graph-by-graph:param:GraphID,param:SplitModuleName'.
For an example, refer to the comment section in `ThunderCompilerGraphBenchmarking`.

Reference: https://pytest-benchmark.readthedocs.io/en/latest/hooks.html#pytest_benchmark.hookspec.pytest_benchmark_group_stats
"""
prefix = "graph-by-graph:"
outcome = yield
if group_by.startswith(prefix):
group_by = group_by[len(prefix) :]
for bench in benchmarks:
if bench["params"] is None:
bench["params"] = {}
# Benchmarks with the same `params` share the same dict.
# We need to create a deepcopy of the original dictionary to add parameters specific to each graph.
else:
bench["params"] = bench["params"].copy()
if bench["param"] is None:
bench["param"] = ""

name = bench["name"]
gid, module_name, ex = name.split("-")[-3:]
# Add the "GraphID", "SplitModuleName","executor" as params in benchmark
gid_key, module_name_key, ex_key = GRAPH_BY_GRAPH_BENCHMARK_PARAMS_KEYS
bench["params"].update({gid_key: gid, module_name_key: module_name, ex_key: ex})
bench["param"] += f"-{gid}-{module_name}-{ex}"

result = pytest_benchmark.plugin.pytest_benchmark_group_stats(config, benchmarks, group_by)
outcome.force_result(result)
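For illustration, this is how the hook recovers the grouping parameters from a suffixed benchmark name (the name below is hypothetical; real names depend on the test function and the splitter's module naming):

# Hypothetical benchmark name following the suffix scheme added in run_bench.
name = "test_func-GraphID[1]-SplitModuleName[thunder_1]-executor[thunder]"
gid, module_name, ex = name.split("-")[-3:]
assert (gid, module_name, ex) == ("GraphID[1]", "SplitModuleName[thunder_1]", "executor[thunder]")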
45 changes: 45 additions & 0 deletions thunder/tests/test_dynamo.py
@@ -4,6 +4,7 @@

from thunder import dtypes
from thunder.dynamo import ThunderCompiler
from thunder.dynamo.compiler_graph_benchmark import ThunderCompilerGraphBenchmarking
from thunder import last_traces
from thunder.tests.bf16 import device_supports_bf16
from thunder.tests.framework import (
@@ -465,3 +466,47 @@ def func(x):
actual_grad = torch.autograd.grad(actual, x, g)
expected_grad = torch.autograd.grad(expected, x, g)
torch.testing.assert_close(actual_grad, expected_grad)


# Sample command to run the benchmark using ThunderCompilerGraphBenchmarking
# pytest thunder/tests/test_dynamo.py -k test_ThunderCompilerGraphBenchmarking_groupby --benchmark-group-by='graph-by-graph:param:GraphID,param:SplitModuleName'
# For more details, see :class:`thunder.dynamo.compiler_graph_benchmark.ThunderCompilerGraphBenchmarking`
# NOTE: The conftest.py file customizes the benchmark grouping behavior for ThunderCompilerGraphBenchmarking.
# It must be located in the same folder as the test file to ensure the grouping configuration is applied.
@requiresCUDA
def test_ThunderCompilerGraphBenchmarking_LlamaMLPBenchmark(benchmark):
backend = ThunderCompilerGraphBenchmarking(benchmark, executors=["thunder", "inductor", "eager"])
from thunder.benchmarks import LlamaMLPBenchmark, Benchmark

bench: Benchmark = LlamaMLPBenchmark(
config="Llama-2-7b-hf",
batchdims=(2,),
device="cuda:0",
requires_grad=True,
)

args, kwargs = bench.make_batch()
# Using torch.compile here fails with "TypeError: cannot pickle '_io.TextIOWrapper' object" in
# https://github.com/Lightning-AI/pytorch-lightning/blob/828fd998961f6a60f92c35254bb94d6e049ad069/src/lightning/fabric/wrappers.py#L421
fn = torch._dynamo.optimize(backend=backend)(bench.fn())
fn(*args, **kwargs)


@requiresCUDA
def test_ThunderCompilerGraphBenchmarking_groupby(benchmark):
def f(x, y):
x = torch.sin(x)
if x.sum() > 0:
x = x.exp()
y = torch.sinc(x) + torch.cos(y)
return y + 1
else:
y = y.exp()
x = torch.sinc(y) + torch.cos(x)
return x - 1

backend = ThunderCompilerGraphBenchmarking(benchmark, executors=["thunder", "inductor", "eager"])
compiled = torch.compile(backend=backend)(f)
x = torch.ones(2, requires_grad=True).cuda()
y = torch.ones(2, requires_grad=True).cuda()
compiled(x, y)