diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index c9449f9f87..bfe2cb376f 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -404,6 +404,9 @@ def reverse_transform_state_dict_for_submodule( ) -> dict[str, Any]: return state_dict + def __repr__(self) -> str: + return f"{self.__class__.__module__}.{self.__class__.__name__}()" + def order_proxies(bsyms: Sequence[BoundSymbol]) -> dict[str, int]: """computes a canonical ordering of proxies in the bound symbols based on the order of appearance diff --git a/thunder/dynamo/compiler.py b/thunder/dynamo/compiler.py index d61aa10812..9eb7dd851e 100644 --- a/thunder/dynamo/compiler.py +++ b/thunder/dynamo/compiler.py @@ -7,11 +7,13 @@ import torch from thunder.core.baseutils import run_once -from thunder.dynamo.utils import recompile_graph, remove_empty_autocast +from thunder.core.utils import safe_zip +from thunder.dynamo.utils import recompile_graph, remove_empty_autocast, reproducer, CompilerType from thunder.dynamo.splitter import _splitter if TYPE_CHECKING: from thunder.dynamo.utils import SubgraphInfo + from os import PathLike @run_once @@ -83,3 +85,45 @@ def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, tor split_module, subgraph_info = _splitter(gm, self._thunder_jit, self._torch_compile, sample_args) self.subgraph_infos.append(subgraph_info) return split_module + + def save_reproducer_to_folder(self, reproducer_folder: str | PathLike, use_pytest_benchmark: bool = False): + """ + Save the reproducer script for the GraphModule executed by Thunder to the specified `reproducer_folder`. + Each saved script is named as "graph[graph_id]_thunder_[module_id]", where: + + - `graph_id` indexes the graph generated by Dynamo, which is then passed to Thunder. + - `module_id` indexes the submodule split by the :func:`thunder.dynamo.utils._splitter`. + + Args: + reproducer_folder (str | PathLike): The folder where the reproducer code will be written. Can be specified as an absolute or relative path. + use_pytest_benchmark (str): Determines the type of script to create: + + - If use_pytest_benchmark=False: Creates a reproducer script. + - If use_pytest_benchmark=True: Creates a benchmark script to compare the reproducer's performance with other backends, including Torch eager, torch.compile, and torch.compile with `backend="eager"`. + """ + if not self.subgraph_infos: + raise TypeError(f"{self} doesn't seem to have been called yet.") + + for graph_idx, subgraph_info in enumerate(self.subgraph_infos): + thunder_module_names = [] + for node in subgraph_info.split_graph_module.graph.nodes: + target = node.target + if isinstance(target, str) and target.startswith("thunder_"): + thunder_module_names.append(target) + original_thunder_modules = ( + m + for m, compiled_m in subgraph_info.submodule_to_compiled_functions.items() + if compiled_m.compiler == CompilerType.THUNDER + ) + example_inputs = subgraph_info.thunder_compiled_fns_example_inputs + for cur_module, example_input, cur_name in safe_zip( + original_thunder_modules, example_inputs, thunder_module_names + ): + reproducer( + cur_module, + self.thunder_options, + example_input, + reproducer_folder, + f"graph{graph_idx}_{cur_name}", + use_pytest_benchmark, + ) diff --git a/thunder/dynamo/compiler_graph_benchmark.py b/thunder/dynamo/compiler_graph_benchmark.py index ddb7f80e53..5dc6eecd37 100644 --- a/thunder/dynamo/compiler_graph_benchmark.py +++ b/thunder/dynamo/compiler_graph_benchmark.py @@ -103,19 +103,25 @@ def run_bench(self, gm: torch.fx.GraphModule, name: str, *sample_args): if self.post_graph: compiled_fn = self.post_graph(compiled_fn, sample_args) - with record_peak_allocated_memory(self.bench): + # This guard ensures compatibility with CPU-only PyTorch builds. + if torch.cuda.is_available(): + with record_peak_allocated_memory(self.bench): + self.bench(compiled_fn, *sample_args) + else: self.bench(compiled_fn, *sample_args) # BenchmarkFixture.stats is created each time bench is called (ref: https://github.com/pybenchmark/pytest-benchmark/blob/8c9a5faa1dd178b53ab7b2a66f5364a77e903d74/src/pytest_benchmark/fixture.py#L150) # Adds the graph number, split module name and executor suffix to the name string gid_key, module_name_key, ex_key = GRAPH_BY_GRAPH_BENCHMARK_PARAMS_KEYS - self.bench.stats.name += f"-{gid_key}[{self.graph_idx+1}]-{module_name_key}[{name}]-{ex_key}[{ex_name}]" - assert MAX_ALLOCATED_MEMORY_KEYWORD in self.bench.extra_info - assert f"{self.bench.stats.name}_{MAX_ALLOCATED_MEMORY_KEYWORD}" not in self.bench.extra_info - # NOTE: A benchmark can include multiple stats, but only one extra_info field is allowed per benchmark. - # Therefore, we use the current stats name as a prefix to distinguish memory usage for each stats. - self.bench.extra_info[f"{self.bench.stats.name}_{MAX_ALLOCATED_MEMORY_KEYWORD}"] = ( - self.bench.extra_info.pop(MAX_ALLOCATED_MEMORY_KEYWORD) - ) + self.bench.stats.name += f"-{gid_key}[{self.graph_idx}]-{module_name_key}[{name}]-{ex_key}[{ex_name}]" + + if torch.cuda.is_available(): + assert MAX_ALLOCATED_MEMORY_KEYWORD in self.bench.extra_info + assert f"{self.bench.stats.name}_{MAX_ALLOCATED_MEMORY_KEYWORD}" not in self.bench.extra_info + # NOTE: A benchmark can include multiple stats, but only one extra_info field is allowed per benchmark. + # Therefore, we use the current stats name as a prefix to distinguish memory usage for each stats. + self.bench.extra_info[f"{self.bench.stats.name}_{MAX_ALLOCATED_MEMORY_KEYWORD}"] = ( + self.bench.extra_info.pop(MAX_ALLOCATED_MEMORY_KEYWORD) + ) # when the graph is segmented, the self.bench run multiple times, pybenchmark throws an error: # `FixtureAlreadyUsed("Fixture can only be used once. Previously it was used in %s mode." % self._mode)` @@ -158,7 +164,7 @@ def has_checkpoint_node(g): cur_nodes = cur_module.graph.nodes # Greates random input values for the current module based on the faketensor 'example_value' of the placeholder node placeholders = list(n for n in cur_nodes if n.op == "placeholder") - args = chain(*map(_get_example_inputs_from_placeholder, placeholders)) + args = list(map(_get_example_inputs_from_placeholder, placeholders)) # Runs the benchmark on the original module with the generated random inputs self.run_bench(compiled_functions_to_submodule[cur_module], target, *args) self.graph_idx += 1 diff --git a/thunder/dynamo/splitter.py b/thunder/dynamo/splitter.py index b128357b97..4b455f60b6 100644 --- a/thunder/dynamo/splitter.py +++ b/thunder/dynamo/splitter.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import TYPE_CHECKING import copy +from functools import partial import torch from torch.fx.passes.split_module import split_module @@ -16,6 +17,7 @@ update_node_and_submodule, recompile_graph, checkpoint_converter, + _get_example_inputs_from_placeholder, ) if TYPE_CHECKING: @@ -124,8 +126,9 @@ def callback(node) -> int: return partition_cnt # There is a flip. Either from supported to unsupported or unsupported to supported. + if prev_value is not None: + partition_cnt += 1 # Bump the region cnt. prev_value = is_thunder_supported - partition_cnt += 1 # Bump the region cnt. if is_thunder_supported: supported_partitions.add(partition_cnt) @@ -142,11 +145,18 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool: # Call compile on the split region/s. thunder_compiled_fns = [] + example_input_metadatas = [] submodule_to_compiled_fns = {} for node in split_gm.graph.nodes: node_name = node.name if is_thunder_supported_partition(node): graph_module = getattr(split_gm, node.name) + # Record the input tensor metadata of the current module based on the faketensor 'example_value' of the placeholder node + placeholders = list(n for n in graph_module.graph.nodes if n.op == "placeholder") + example_input_metadata = map( + partial(_get_example_inputs_from_placeholder, only_metadata=True), placeholders + ) + example_input_metadatas.append(list(example_input_metadata)) # Replace PyTorch operators within the checkpointed function with the corresponding Thunder operators checkpoint_converter(split_gm, graph_module) jit_fn = thunder_jit(graph_module) @@ -176,6 +186,7 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool: original_split_gm, split_gm, thunder_compiled_fns, + example_input_metadatas, submodule_to_compiled_fns, split_reasons, ) diff --git a/thunder/dynamo/utils.py b/thunder/dynamo/utils.py index 668f2ef0bc..4bab617cde 100644 --- a/thunder/dynamo/utils.py +++ b/thunder/dynamo/utils.py @@ -6,8 +6,11 @@ import inspect import itertools import copy +from pathlib import Path import torch +from torch.nn.modules.module import _addindent +from torch._subclasses.fake_tensor import FakeTensor from thunder.torch.default_torch_ops import torch_auto_registered_ops from thunder.torch import _torch_to_thunder_function_map @@ -16,6 +19,9 @@ if TYPE_CHECKING: from thunder.core.symbol import Symbol + import os + from typing import Any + from collections.abc import Sequence auto_register_ops = set(itertools.chain(*torch_auto_registered_ops.values())) @@ -74,6 +80,26 @@ class SplitReason: exception: Exception | None = None +@dataclasses.dataclass(frozen=True) +class ExampleInputMetaData: + """ + Describes the metadata of a tensor, used to generate a random tensor with matching properties + """ + + requires_grad: bool + layout: torch.layout + device: str | torch.device + dtype: torch.dtype + shape: list[int] + storage_shape: list[int] + strides: list[int] + min_val: int | None = None + max_val: int | None = None + + def stride(self) -> list[int]: + return self.strides + + @dataclasses.dataclass(frozen=True) class SubgraphInfo: """A dataclass containing information about a subgraph. @@ -87,6 +113,8 @@ class SubgraphInfo: thunder_compiled_fns: List of thunder optimized callables. This could be :obj:`None` if there the graph module was not supported by thunder. Look at the :attr:`split_reasons` for further information. + thunder_compiled_fns_example_inputs: List containing metadata of sample inputs for `thunder_compiled_fns`. + These inputs are used to generate random test inputs in the reproducer script. submodule_to_compiled_functions: Dict from subgraph in :attr:`original_split_graph_module` to compiled function. This will be a dict with one pair in case the graph was not split. split_reasons: List of reasons explaining why the subgraph was split. @@ -97,13 +125,14 @@ class SubgraphInfo: original_split_graph_module: torch.fx.GraphModule | None split_graph_module: torch.fx.GraphModule | None thunder_compiled_fns: list[Callable] | None + thunder_compiled_fns_example_inputs: list[list[ExampleInputMetaData]] | None submodule_to_compiled_functions: dict[torch.fx.GraphModule, CompiledFunction] split_reasons: list | None = None -def _concrete_shape(x): +def _concrete_value(vals: torch.Size | Sequence): """ - Get the concrete shape for a FakeTensor if it has `torch.SymInt` in its shape. + Get the concrete value from the input `vals` if it contains `torch.SymInt`. """ def get_backed_value(s): @@ -112,7 +141,7 @@ def get_backed_value(s): # Value is already concrete. return s - return tuple(map(get_backed_value, x.shape)) + return tuple(map(get_backed_value, vals)) def get_proxy_inputs_from_node(node: torch.fx.Node) -> tuple[tuple, dict]: @@ -147,11 +176,12 @@ def make_tensor_proxy(arg_node): # Here, we only want to verify that thunder can run an operation. # So, it is ok to verify with concrete value. example_value = example_value.new_ones( - _concrete_shape(example_value), device=example_value.device, dtype=example_value.dtype + _concrete_value(example_value.shape), device=example_value.device, dtype=example_value.dtype ) elif isinstance(example_value, tuple): example_value = tuple( - e_v.new_ones(_concrete_shape(e_v), device=e_v.device, dtype=e_v.dtype) for e_v in example_value + e_v.new_ones(_concrete_value(e_v.shape), device=e_v.device, dtype=e_v.dtype) + for e_v in example_value ) else: # NOTE - This will be caught will be caught and be part of the SplitReason. @@ -424,43 +454,78 @@ def recompile_graph(gm: torch.fx.GraphModule): return gm.recompile() -def _get_example_inputs_from_placeholder(node) -> tuple[torch.Tensor]: +def _get_storage_shape(t: torch.Tensor): + shape = _concrete_value(t.shape) + if t.is_contiguous(): + return shape + strides = _concrete_value(t.stride()) + storage_size = sum(strides[i] * (shape[i] - 1) for i in range(len(shape))) + 1 + return (storage_size,) + + +def _get_example_input_tensor_metadata(t: torch.Tensor) -> ExampleInputMetaData: + min_val = None + max_val = None + if not isinstance(t, FakeTensor) and t.numel() != 0: + minmax: tuple[torch.Tensor, torch.Tensor] = torch.aminmax(t) + min_val = minmax[0].cpu().item() + max_val = minmax[1].cpu().item() + meta_ev = ExampleInputMetaData( + t.requires_grad, + t.layout, + t.device, + t.dtype, + _concrete_value(t.shape), + _get_storage_shape(t), + _concrete_value(t.stride()), + min_val, + max_val, + ) + return meta_ev + + +def _create_random_tensor_from_tensor_metadata(t: ExampleInputMetaData) -> torch.Tensor: from thunder.tests.make_tensor import make_tensor + return make_tensor(t.storage_shape, dtype=t.dtype, device=t.device, requires_grad=t.requires_grad).as_strided( + t.shape, t.stride() + ) + + +def _get_example_inputs_from_placeholder( + node: torch.fx.Node, only_metadata=False +) -> tuple[torch.Tensor | ExampleInputMetaData] | torch.Tensor | ExampleInputMetaData: + """Retrieves example input data for a given placeholder `torch.fx.Node`. + - When `only_metadata` is `False`: Generates and returns a random example tensor based on the node's expected shape and data type, etc. + - When `only_metadata` is `True`: Returns only the tensor's metadata (e.g., shape, data type) without generating an actual tensor. + """ check(node.op == "placeholder", lambda: f"The node must be placeholder type", ValueError) # Prefers to use actual example value in GraphArg if available if "grapharg" in node.meta: - example_value = node.meta["grapharg"].example - if isinstance(example_value, torch.Tensor): - return (example_value.detach().clone().requires_grad_(example_value.requires_grad),) - - check("example_value" in node.meta, lambda: "example_value does not exist in the meta of {node}", ValueError) + ev = node.meta["grapharg"].example + if isinstance(ev, torch.Tensor): + if only_metadata: + return _get_example_input_tensor_metadata(ev) + return ev.detach().clone().requires_grad_(ev.requires_grad) + + if "example_value" not in node.meta: + return None example_value = node.meta["example_value"] if isinstance(example_value, torch.Tensor): - sz = _concrete_shape(example_value) - return ( - make_tensor( - sz, - dtype=example_value.dtype, - device=example_value.device, - requires_grad=example_value.requires_grad, - ).as_strided(sz, example_value.stride()), - ) + ev_metadata = _get_example_input_tensor_metadata(example_value) + if only_metadata: + return ev_metadata + return _create_random_tensor_from_tensor_metadata(ev_metadata) elif isinstance(example_value, tuple): - return tuple( - make_tensor( - _concrete_shape(e_v), - dtype=e_v.dtype, - device=e_v.device, - requires_grad=e_v.requires_grad, - ).as_strided(_concrete_shape(e_v), e_v.stride()) - for e_v in example_value - ) + ev_metadatas = tuple(_get_example_input_tensor_metadata(e_v) for e_v in example_value) + if only_metadata: + return ev_metadatas + return tuple(_create_random_tensor_from_tensor_metadata(ev_metadata) for ev_metadata in ev_metadatas) + elif isinstance(example_value, torch.types.py_sym_types): + return example_value.node.hint else: - raise TypeError( - "The 'example_value' in the placeholder node is expected to be either a Tensor or a Tuple of Tensors." - ) + raise TypeError(f"Unsupported example_value type: {type(example_value)}") def _checkpoint_function_converter(gm: torch.fx.GraphModule): @@ -555,3 +620,179 @@ def remove_empty_autocast(graph_module: torch.fx.GraphModule) -> torch.fx.GraphM empty_autocast_removed_graph_module.graph.erase_node(node) return empty_autocast_removed_graph_module + + +def arg_like_tensor(arg: torch.Tensor | ExampleInputMetaData): + """Creates a new argument like the given tensor or tensor metadata""" + min_val = None + max_val = None + if isinstance(arg, torch.Tensor): + if arg.numel() != 0: + min_val, max_val = torch.aminmax(arg) + min_val = min_val.cpu().item() + max_val = max_val.cpu().item() + else: + min_val, max_val = arg.min_val, arg.max_val + storage_shape = _get_storage_shape(arg) if isinstance(arg, torch.Tensor) else arg.storage_shape + if min_val is not None and min_val == max_val: + meta = f"{storage_shape}, {min_val}, dtype={arg.dtype}, device='{arg.device}', requires_grad={arg.requires_grad}, layout={arg.layout}" + return f"torch.full({meta}).as_strided({arg.shape}, {arg.stride()})," + meta = f"{storage_shape}, dtype={arg.dtype}, device='{arg.device}', requires_grad={arg.requires_grad}," + meta = f"{meta} low={min_val}, high={max_val}," + return f"torch.testing.make_tensor({meta}).as_strided({arg.shape}, {arg.stride()})," + + +def arg_like(arg: Any): + """Creates a new argument that is similar to the given arg.""" + if isinstance(arg, (torch.Tensor, ExampleInputMetaData)): + return f"{arg_like_tensor(arg)}" + else: + # Assume it's a literal that we can just print directly. + return f"{arg}," + + +def _readable( + module: torch.fx.GraphModule, + module_name: str, + print_output: bool = False, + include_stride: bool = True, + include_device: bool = True, + colored: bool = False, +): + """Modified from `torch.fx.graph_module._print_readable` (https://github.com/pytorch/pytorch/blob/3192bdeea428f2bf3a95274ee59ea41c4f8e31e9/torch/fx/graph_module.py#L297). + This is basically print_readable but it sets verbose=False (torch hardcodes it to True).""" + graph = module.graph + assert graph is not None and isinstance( + graph, torch.fx.Graph + ), "print_readable must be used on a module with a graph" + + verbose_python_code = graph.python_code( + root_module="self", + verbose=False, + include_stride=include_stride, + include_device=include_device, + colored=colored, + ) + module_code = verbose_python_code.src + module_code = module_code.lstrip("\n") + module_code = f"class {module_name}(torch.nn.Module):\n" + module_code + module_code = _addindent(module_code, 2) + + submodule_code_list = [""] + for submodule_name, submodule in module.named_children(): + if hasattr(submodule, "graph"): + submodule_code_list.append( + _readable( + submodule, + submodule_name, + print_output=False, + include_stride=include_stride, + include_device=include_device, + colored=colored, + ) + ) + submodule_code = "\n".join(submodule_code_list) + submodule_code = _addindent(submodule_code, 2) + + output = module_code + submodule_code + if print_output: + print(module_code + submodule_code) + return output + + +def get_env() -> tuple[str, str]: + """Retrieve detailed environment information using `torch.utils.collect_env.get_pretty_env_info()`. + Additionally, include the installed versions of Thunder and NvFuser (if available via pip). + """ + + from torch.utils.collect_env import run, get_pretty_env_info, get_pip_packages + + torch_env = get_pretty_env_info() + _, thunder_packages = get_pip_packages(run, {"lightning-thunder", "nvfuser"}) + return torch_env, thunder_packages + + +def thunder_options_to_str(thunder_options: dict) -> str: + from thunder import resolve_executors + + option_str = "" + for key, value in thunder_options.items(): + if key == "executors": + executors = resolve_executors(value) + option_str += f"{key}=[" + ",".join(f"thunder.extend.get_executor('{ex.name}')" for ex in executors) + "]" + else: + option_str += f"{key}={repr(value)}" + option_str += "," + return option_str + + +def reproducer( + gm: torch.fx.GraphModule, + thunder_options: dict, + args: tuple[torch.Tensor | ExampleInputMetaData], + folder: str | os.PathLike, + graph_name: str, + use_pytest_benchmark: bool = False, +): + folder = Path(folder) + folder.mkdir(exist_ok=True) + torch_env, thunder_pkgs = get_env() + # Ideally we'd use print_readable, but we want verbose=False and there's no + # way to set that with print_readable. + readable = _readable(gm, "DynamoModule", print_output=False) + has_cuda_args = any(hasattr(arg, "device") and arg.device.type == "cuda" for arg in args) + thunder_options_str = thunder_options_to_str(thunder_options) + with open(folder / f"{graph_name}.py", "w") as f: + comment_str = f'''""" +Environment information get from `torch.utils.collect_env.get_pretty_env_info()`: +{torch_env} + +Versions of Thunder related libraries: +{thunder_pkgs} + +The torch.fx.Graph: +{gm.graph} +""" +''' + if use_pytest_benchmark: + comment_str += f"""# NOTE: This script requires `pytest-benchmark==4.0.0` to be installed. +# To execute the script, run `pytest {graph_name}.py`""" + import_str = f"from functools import partial\n\nimport torch\nimport thunder\n" + if has_cuda_args: + import_str += "from thunder.transforms.cudagraph import CUDAGraphTransform\n" + import_str += "from thunder.dev_utils.nvtx_profile_transform import NvtxProfileTransform\n" + if use_pytest_benchmark: + code_str = f"def test_{graph_name}(benchmark):\n{readable}\n" + else: + code_str = f"def test_{graph_name}():\n{readable}\n" + + if any(arg is None for arg in args): + code_str += f"# Warning: The inputs that cannot be inferred are set to None, requiring the user to manually give inputs according to the code\n" + input_str = f"""inputs = [\n{chr(10).join(arg_like(a) for a in args)}\n""" + code_str += f"{_addindent(input_str, 4)}\n]\n" + + if not use_pytest_benchmark: + code_str += f"compiled = thunder.jit(DynamoModule(), {thunder_options_str})\n" + code_str += "compiled(*inputs)" + else: + code_str += "from thunder.dynamo.compiler_graph_benchmark import ThunderCompilerGraphBenchmarking\n" + code_str = f"""{code_str} +bench_executors_dict = {{}} +bench_executors_dict["thunder"]=partial(thunder.jit, {thunder_options_str}) +bench_executors_dict["torch.compile"]=torch.compile +bench_executors_dict["dynamo_eager"]=partial(torch.compile, backend="eager") +bench_executors_dict["eager"]=None +""" + if has_cuda_args: + code_str = f"""{code_str}bench_executors_dict["thunder_cugraph"]=partial(thunder.jit, transform=CUDAGraphTransform())\n""" + code_str += f""" +backend = ThunderCompilerGraphBenchmarking(benchmark, executors=bench_executors_dict) +compiled = torch.compile(backend=backend)(DynamoModule()) +compiled(*inputs) +""" + print(comment_str, file=f) + print(import_str, file=f) + print(_addindent(code_str, 4), file=f) + + if not use_pytest_benchmark: + print(f"\ntest_{graph_name}()", file=f) diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py index 65e6603f54..cc740ff408 100644 --- a/thunder/tests/test_dynamo.py +++ b/thunder/tests/test_dynamo.py @@ -1,6 +1,8 @@ import pytest import warnings import itertools +import os +from subprocess import run import torch import torch.fx import torch.nn as nn @@ -558,7 +560,7 @@ def f(x): all_nodes = itertools.chain( backend.subgraph_infos[0].split_graph_module.graph.nodes, - backend.subgraph_infos[0].split_graph_module.thunder_1.graph.nodes, + backend.subgraph_infos[0].split_graph_module.thunder_0.graph.nodes, ) assert all(node.target not in autocast_ops for node in all_nodes) @@ -575,7 +577,7 @@ def f(x): backend = _call_thunder_backend(f, (x,)) all_nodes = itertools.chain( backend.subgraph_infos[0].split_graph_module.graph.nodes, - backend.subgraph_infos[0].split_graph_module.thunder_1.graph.nodes, + backend.subgraph_infos[0].split_graph_module.thunder_0.graph.nodes, ) assert sum(node.target in autocast_ops for node in all_nodes) == 2 @@ -782,3 +784,83 @@ def find_target_module(model, target_module_name): for n in submodule.graph.nodes: if n.op == "call_function": assert isinstance(n.target, Symbol) + + +@instantiate( + dtypes=NOTHING, + executors=[DynamoThunderExecutor], + decorators=(pytest.mark.parametrize("use_pytest_benchmark", (True, False), ids=("benchmark", "repro")),), +) +def test_dynamo_reproducer_2graph(executor, device: str, dtype: dtypes.dtype, use_pytest_benchmark, tmp_path): + from thunder.dev_utils.nvtx_profile_transform import NvtxProfileTransform + from thunder import nvfuser_executor + from thunder.transforms.cudagraph import CUDAGraphTransform + + if device.startswith("cuda"): + backend = ThunderCompiler( + transforms=[ + NvtxProfileTransform(), + CUDAGraphTransform(), + ], + executors=[nvfuser_executor], + cache="constant values", + langctx=None, + record_history=False, + ) + else: + backend = ThunderCompiler(executors=None) + # Test non-contiguous input tensor + x = make_tensor((4, 4), low=3, high=10, dtype=torch.int64, device=device, noncontiguous=True) + + @torch.compile(backend=backend) + def func(x): + x = torch.sin(x) + if x.sum() > 0: + return x + 1 + else: + return x - 1 + + out = func(x) + backend.save_reproducer_to_folder(tmp_path, use_pytest_benchmark=use_pytest_benchmark) + + s1 = f"{tmp_path}/graph0_thunder_0.py" + s2 = f"{tmp_path}/graph1_thunder_0.py" + assert os.path.exists(s1) + assert os.path.exists(s2) + cmd = "pytest" if use_pytest_benchmark else "python" + result1 = run([cmd, s1], capture_output=True, text=True) + result2 = run([cmd, s2], capture_output=True, text=True) + + assert result1.returncode == 0, f"Reproducer {s1} failed with return code {result1.returncode}" + assert result2.returncode == 0, f"Reproducer {s2} failed with return code {result2.returncode}" + + +@requiresCUDA +@pytest.mark.parametrize("use_pytest_benchmark", (True, False), ids=("benchmark", "repro")) +def test_dynamo_reproducer_submodules(use_pytest_benchmark, tmp_path): + from thunder.tests.distributed.helper import ToyModel + import torch.nn as nn + + class SimpleModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.sub_mod = ToyModel() + self.seq = nn.Sequential(self.sub_mod, nn.ReLU()) + + def forward(self, x): + x = torch.sin(x) + x = self.seq(x) + return x + + x = torch.randn(1, ToyModel.N_IN, device="cuda", requires_grad=True) + model = SimpleModel().cuda() + backend = ThunderCompiler() + jf = torch.compile(backend=backend)(model) + out = jf(x) + backend.save_reproducer_to_folder(tmp_path, use_pytest_benchmark=use_pytest_benchmark) + + s1 = f"{tmp_path}/graph0_thunder_0.py" + assert os.path.exists(s1) + cmd = "pytest" if use_pytest_benchmark else "python" + result1 = run([cmd, s1], capture_output=True, text=True) + assert result1.returncode == 0, f"Reproducer {s1} failed with return code {result1.returncode}"