diff --git a/thunder/benchmarks/targets.py b/thunder/benchmarks/targets.py
index 94f892b519..36f09ff060 100644
--- a/thunder/benchmarks/targets.py
+++ b/thunder/benchmarks/targets.py
@@ -56,6 +56,7 @@
     "phi-2",
 ]
 RUN_ALL_CONFIGS = os.environ.get("THUNDER_BENCH_RUN_ALL_CONFIGS", "0") == "1"
+MAX_ALLOCATED_MEMORY_KEYWORD = "max_allocated_memory_MB"
 
 
 class ComputeType(Enum):
@@ -113,7 +114,7 @@ def deco(old_timer):
         @functools.wraps(old_timer)
         def timer():
             ret = old_timer()
-            benchmark.extra_info["max_allocated_memory_MB"] = torch.cuda.max_memory_allocated() / (1024 * 1024.0)
+            benchmark.extra_info[MAX_ALLOCATED_MEMORY_KEYWORD] = torch.cuda.max_memory_allocated() / (1024 * 1024.0)
             torch.cuda.reset_peak_memory_stats()
             return ret
 
diff --git a/thunder/dynamo/compiler_graph_benchmark.py b/thunder/dynamo/compiler_graph_benchmark.py
new file mode 100644
index 0000000000..7f7fbccf6f
--- /dev/null
+++ b/thunder/dynamo/compiler_graph_benchmark.py
@@ -0,0 +1,146 @@
+from __future__ import annotations
+from itertools import chain
+from pytest_benchmark.fixture import BenchmarkFixture
+from typing import TYPE_CHECKING
+
+import torch
+from thunder.dynamo import ThunderCompiler
+from thunder.dynamo.utils import _get_example_inputs_from_placeholder
+from thunder.core.utils import check
+
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+GRAPH_BY_GRAPH_BENCHMARK_PARAMS_KEYS = ("GraphID", "SplitModuleName", "executor")
+
+
+class ThunderCompilerGraphBenchmarking(ThunderCompiler):
+    _executors = (
+        "eager",
+        "inductor",
+        "thunder",
+    )
+
+    def __init__(
+        self,
+        bench: BenchmarkFixture,
+        executors: Sequence[str],
+        **thunder_options,
+    ):
+        """
+        This class acts as a backend for the :func:`torch.compile` function, facilitating the benchmarking of each :class:`torch.fx.GraphModule` produced by the Thunder dynamo splitter.
+        Each :class:`torch.fx.GraphModule` instance is executed by the specified executors and benchmarked using `pytest-benchmark`.
+
+        Args:
+            bench: the BenchmarkFixture created by ``pytest_benchmark``
+            executors: list of executors to compare. Supported executors include: 'eager', 'inductor', and 'thunder'. If None, defaults to all available executors.
+            **thunder_options: a dictionary of options to pass to :func:`thunder.jit`. Besides all the arguments to :func:`thunder.jit`,
+                it accepts `torch_inductor_options`, which is passed to :func:`torch.compile` if part of the graph
+                is not supported by thunder.
+
+        Example:
+            .. code-block:: python
+
+                # script.py
+                import torch
+                from thunder.dynamo.compiler_graph_benchmark import ThunderCompilerGraphBenchmarking
+
+                def func(x):
+                    x = torch.sin(x)
+                    if x.sum() > 0:
+                        return x + 1
+                    else:
+                        return x - 1
+
+                def test_func(benchmark):
+                    backend = ThunderCompilerGraphBenchmarking(benchmark, executors=["eager", "thunder"])
+                    compiled = torch.compile(backend=backend)(func)
+                    x = torch.ones(2, requires_grad=True).cuda()
+                    compiled(x)
+
+        Note:
+            Ensure the pytest configuration file (`thunder/tests/conftest.py`) is present in the same directory as `script.py` so that the custom grouping is applied.
+
+            To run the benchmark test and group the results by split module, execute the following command:
+            `pytest script.py --benchmark-group-by='graph-by-graph:param:GraphID,param:SplitModuleName'`
+
+            In this example, Dynamo segments the graph into two subgraphs, each identified by the 'GraphID[id]' field in the test name.
+            Each subgraph contains a single split module, processed by the Thunder-defined splitter,
+            which corresponds to the 'SplitModuleName[split_module_name]' field.
+            The currently active executor is indicated by the 'executor[executor_name]' field.
+            With `--benchmark-group-by='graph-by-graph:param:GraphID,param:SplitModuleName'`, the test cases are grouped based on GraphID and SplitModuleName,
+            allowing for performance comparison between different executors (e.g., 'eager' vs. 'thunder').
+        """
+        super().__init__(**thunder_options)
+        self.bench = bench
+        if not executors:
+            self.executors = ThunderCompilerGraphBenchmarking._executors
+        else:
+            check(
+                all(ex in ThunderCompilerGraphBenchmarking._executors for ex in executors),
+                lambda: f"ThunderCompilerGraphBenchmarking only supports the following executor names: {ThunderCompilerGraphBenchmarking._executors} ",
+            )
+            self.executors = executors
+        self.graph_idx = 0
+
+    def run_bench(self, gm: torch.fx.GraphModule, name: str, *sample_args):
+        from thunder.benchmarks.targets import record_peak_allocated_memory, MAX_ALLOCATED_MEMORY_KEYWORD
+
+        for ex in self.executors:
+            # Uses the already compiled module if it was compiled with the expected executor
+            if name.startswith(ex):
+                fn = self.subgraph_infos[self.graph_idx].submodule_to_compiled_functions[gm].compiled_fn
+            else:
+                if ex == "thunder":
+                    # The subgraph whose name starts with "inductor" is not supported by the Thunder backend.
+                    if name.startswith("inductor"):
+                        continue
+                    fn = self._thunder_jit(gm)
+                elif ex == "inductor":
+                    fn = self._torch_compile(gm)
+                else:
+                    fn = gm
+            with record_peak_allocated_memory(self.bench):
+                self.bench(fn, *sample_args)
+            # BenchmarkFixture.stats is created each time bench is called (ref: https://github.com/pybenchmark/pytest-benchmark/blob/8c9a5faa1dd178b53ab7b2a66f5364a77e903d74/src/pytest_benchmark/fixture.py#L150)
+            # Adds the graph number, split module name and executor suffix to the name string
+            gid_key, module_name_key, ex_key = GRAPH_BY_GRAPH_BENCHMARK_PARAMS_KEYS
+            self.bench.stats.name += f"-{gid_key}[{self.graph_idx+1}]-{module_name_key}[{name}]-{ex_key}[{ex}]"
+            assert MAX_ALLOCATED_MEMORY_KEYWORD in self.bench.extra_info
+            assert f"{self.bench.stats.name}_{MAX_ALLOCATED_MEMORY_KEYWORD}" not in self.bench.extra_info
+            # NOTE: A benchmark can include multiple stats, but only one extra_info field is allowed per benchmark.
+            # Therefore, we use the current stats name as a prefix to distinguish the memory usage of each stats entry.
+            self.bench.extra_info[f"{self.bench.stats.name}_{MAX_ALLOCATED_MEMORY_KEYWORD}"] = (
+                self.bench.extra_info.pop(MAX_ALLOCATED_MEMORY_KEYWORD)
+            )
+
+            # When the graph is segmented, self.bench runs multiple times and pytest-benchmark throws an error:
+            # `FixtureAlreadyUsed("Fixture can only be used once. Previously it was used in %s mode." % self._mode)`
+            # Ref: https://github.com/pybenchmark/pytest-benchmark/blob/8c9a5faa1dd178b53ab7b2a66f5364a77e903d74/src/pytest_benchmark/fixture.py#L115-L118
+            # Here we manually set BenchmarkFixture._mode to None to avoid it
+            self.bench._mode = None
+
+    def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, torch.Tensor]):
+        split_module = super().__call__(gm, sample_args)
+        compiled_functions_to_submodule = {
+            v.compiled_fn: k for k, v in self.subgraph_infos[self.graph_idx].submodule_to_compiled_functions.items()
+        }
+        for node in split_module.graph.nodes:
+            target = node.target
+            # Benchmarks the modules produced by the splitter.
+            if isinstance(target, str) and target.startswith(("thunder_", "inductor_")):
+                check(
+                    hasattr(split_module, target),
+                    lambda: f"the submodule {target} does not exist in {split_module}",
+                    ValueError,
+                )
+                cur_module = getattr(split_module, target)
+                cur_nodes = cur_module.graph.nodes
+                # Creates random input values for the current module based on the FakeTensor 'example_value' of the placeholder nodes
+                placeholders = list(n for n in cur_nodes if n.op == "placeholder")
+                args = chain(*map(_get_example_inputs_from_placeholder, placeholders))
+                # Runs the benchmark on the original module with the generated random inputs
+                self.run_bench(compiled_functions_to_submodule[cur_module], target, *args)
+        self.graph_idx += 1
+        return split_module
diff --git a/thunder/dynamo/utils.py b/thunder/dynamo/utils.py
index b0ef08097e..d1041c90dc 100644
--- a/thunder/dynamo/utils.py
+++ b/thunder/dynamo/utils.py
@@ -5,12 +5,14 @@
 import dataclasses
 import inspect
 import itertools
+import warnings
 
 import torch
 from thunder.torch.default_torch_ops import torch_auto_registered_ops
 from thunder.torch import _torch_to_thunder_function_map
 from thunder.torch.langctx import torchctx
+from thunder.core.utils import check
 
 
 if TYPE_CHECKING:
     from thunder.core.symbol import Symbol
@@ -377,3 +379,42 @@ def recompile_graph(gm: torch.fx.GraphModule):
     if isinstance(gm, torch.fx._lazy_graph_module._LazyGraphModule):
         return gm.real_recompile()
     return gm.recompile()
+
+
+def _get_example_inputs_from_placeholder(node) -> tuple[torch.Tensor]:
+    from thunder.tests.make_tensor import make_tensor
+
+    check(node.op == "placeholder", lambda: "The node must be a placeholder node", ValueError)
+    # Prefers to use the actual example value in GraphArg if available
+    if "grapharg" in node.meta:
+        example_value = node.meta["grapharg"].example
+        if isinstance(example_value, torch.Tensor):
+            return (example_value.detach().clone().requires_grad_(example_value.requires_grad),)
+
+    check("example_value" in node.meta, lambda: f"example_value does not exist in the meta of {node}", ValueError)
+    example_value = node.meta["example_value"]
+
+    if isinstance(example_value, torch.Tensor):
+        sz = _concrete_shape(example_value)
+        return (
+            make_tensor(
+                sz,
+                dtype=example_value.dtype,
+                device=example_value.device,
+                requires_grad=example_value.requires_grad,
+            ).as_strided(sz, example_value.stride()),
+        )
+    elif isinstance(example_value, tuple):
+        return tuple(
+            make_tensor(
+                _concrete_shape(e_v),
+                dtype=e_v.dtype,
+                device=e_v.device,
+                requires_grad=e_v.requires_grad,
+            ).as_strided(_concrete_shape(e_v), e_v.stride())
+            for e_v in example_value
+        )
+    else:
+        raise TypeError(
+            "The 'example_value' in the placeholder node is expected to be either a Tensor or a Tuple of Tensors."
+        )
diff --git a/thunder/tests/conftest.py b/thunder/tests/conftest.py
new file mode 100644
index 0000000000..cc357a9787
--- /dev/null
+++ b/thunder/tests/conftest.py
@@ -0,0 +1,39 @@
+import pytest
+from collections import defaultdict
+import pytest_benchmark
+from thunder.dynamo.compiler_graph_benchmark import GRAPH_BY_GRAPH_BENCHMARK_PARAMS_KEYS
+
+
+@pytest.hookimpl(hookwrapper=True)
+def pytest_benchmark_group_stats(config, benchmarks, group_by):
+    """
+    This function customizes the grouping behavior for ThunderCompilerGraphBenchmarking.
+    The custom grouping function is only invoked when the `--benchmark-group-by`
+    option is set to 'graph-by-graph:param:GraphID,param:SplitModuleName'.
+    For an example, refer to the comment section in `ThunderCompilerGraphBenchmarking`.
+
+    Reference: https://pytest-benchmark.readthedocs.io/en/latest/hooks.html#pytest_benchmark.hookspec.pytest_benchmark_group_stats
+    """
+    prefix = "graph-by-graph:"
+    outcome = yield
+    if group_by.startswith(prefix):
+        group_by = group_by[len(prefix) :]
+        for bench in benchmarks:
+            if bench["params"] is None:
+                bench["params"] = {}
+            # Benchmarks with the same `params` share the same dict.
+            # We need to create a copy of the original dictionary to add parameters specific to each graph.
+            else:
+                bench["params"] = bench["params"].copy()
+            if bench["param"] is None:
+                bench["param"] = ""
+
+            name = bench["name"]
+            gid, module_name, ex = name.split("-")[-3:]
+            # Add "GraphID", "SplitModuleName", and "executor" as params in the benchmark
+            gid_key, module_name_key, ex_key = GRAPH_BY_GRAPH_BENCHMARK_PARAMS_KEYS
+            bench["params"].update({gid_key: gid, module_name_key: module_name, ex_key: ex})
+            bench["param"] += f"-{gid}-{module_name}-{ex}"
+
+        result = pytest_benchmark.plugin.pytest_benchmark_group_stats(config, benchmarks, group_by)
+        outcome.force_result(result)
diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py
index a667405a13..cae17adf65 100644
--- a/thunder/tests/test_dynamo.py
+++ b/thunder/tests/test_dynamo.py
@@ -4,6 +4,7 @@
 
 from thunder import dtypes
 from thunder.dynamo import ThunderCompiler
+from thunder.dynamo.compiler_graph_benchmark import ThunderCompilerGraphBenchmarking
 from thunder import last_traces
 from thunder.tests.bf16 import device_supports_bf16
 from thunder.tests.framework import (
@@ -465,3 +466,47 @@ def func(x):
     actual_grad = torch.autograd.grad(actual, x, g)
     expected_grad = torch.autograd.grad(expected, x, g)
     torch.testing.assert_close(actual_grad, expected_grad)
+
+
+# Sample command to run the benchmark using ThunderCompilerGraphBenchmarking
+# pytest thunder/tests/test_dynamo.py -k test_ThunderCompilerGraphBenchmarking_groupby --benchmark-group-by='graph-by-graph:param:GraphID,param:SplitModuleName'
+# For more details, see :class:`thunder.dynamo.compiler_graph_benchmark.ThunderCompilerGraphBenchmarking`
+# NOTE: The conftest.py file customizes the benchmark grouping behavior for ThunderCompilerGraphBenchmarking.
+# It must be located in the same folder as the test file so that the custom grouping configuration is applied.
+@requiresCUDA
+def test_ThunderCompilerGraphBenchmarking_LlamaMLPBenchmark(benchmark):
+    backend = ThunderCompilerGraphBenchmarking(benchmark, executors=["thunder", "inductor", "eager"])
+    from thunder.benchmarks import LlamaMLPBenchmark, Benchmark
+
+    bench: Benchmark = LlamaMLPBenchmark(
+        config="Llama-2-7b-hf",
+        batchdims=(2,),
+        device="cuda:0",
+        requires_grad=True,
+    )
+
+    args, kwargs = bench.make_batch()
+    # Using torch.compile here fails with "TypeError: cannot pickle '_io.TextIOWrapper' object" in
+    # https://github.com/Lightning-AI/pytorch-lightning/blob/828fd998961f6a60f92c35254bb94d6e049ad069/src/lightning/fabric/wrappers.py#L421
+    fn = torch._dynamo.optimize(backend=backend)(bench.fn())
+    fn(*args, **kwargs)
+
+
+@requiresCUDA
+def test_ThunderCompilerGraphBenchmarking_groupby(benchmark):
+    def f(x, y):
+        x = torch.sin(x)
+        if x.sum() > 0:
+            x = x.exp()
+            y = torch.sinc(x) + torch.cos(y)
+            return y + 1
+        else:
+            y = y.exp()
+            x = torch.sinc(y) + torch.cos(x)
+            return x - 1
+
+    backend = ThunderCompilerGraphBenchmarking(benchmark, executors=["thunder", "inductor", "eager"])
+    compiled = torch.compile(backend=backend)(f)
+    x = torch.ones(2, requires_grad=True).cuda()
+    y = torch.ones(2, requires_grad=True).cuda()
+    compiled(x, y)
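As a rough illustration of how the grouping fits together: `run_bench` appends a `-GraphID[...]-SplitModuleName[...]-executor[...]` suffix to each benchmark's stats name, and the `pytest_benchmark_group_stats` hook in `conftest.py` recovers those three values by splitting the name. A minimal sketch of that parsing step follows; the sample name is hypothetical and only its suffix format matters:

    # Hypothetical benchmark name in the format produced by ThunderCompilerGraphBenchmarking.run_bench
    sample_name = "test_func-GraphID[1]-SplitModuleName[thunder_1]-executor[thunder]"

    # Same parsing as in the conftest.py hook: the last three "-"-separated fields
    # carry the GraphID, SplitModuleName, and executor values used for grouping.
    gid, module_name, ex = sample_name.split("-")[-3:]
    params = dict(zip(("GraphID", "SplitModuleName", "executor"), (gid, module_name, ex)))
    print(params)
    # {'GraphID': 'GraphID[1]', 'SplitModuleName': 'SplitModuleName[thunder_1]', 'executor': 'executor[thunder]'}

Grouping by `param:GraphID,param:SplitModuleName` then places the different executor runs of the same split module into the same comparison table.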