From e617b114167e9d0efc06df249eea9cc6cb2c3a21 Mon Sep 17 00:00:00 2001 From: Yan Wang Date: Wed, 27 Nov 2024 17:04:10 +0100 Subject: [PATCH] ThunderFX: Add split reason and submodule structure in repro script (#1461) --- thunder/dynamo/compiler.py | 1 + thunder/dynamo/utils.py | 20 +++++++++++++++++--- thunder/tests/test_dynamo.py | 35 +++++++++++++++++++++++++++++++++++ 3 files changed, 53 insertions(+), 3 deletions(-) diff --git a/thunder/dynamo/compiler.py b/thunder/dynamo/compiler.py index 9eb7dd851e..6168f863d0 100644 --- a/thunder/dynamo/compiler.py +++ b/thunder/dynamo/compiler.py @@ -121,6 +121,7 @@ def save_reproducer_to_folder(self, reproducer_folder: str | PathLike, use_pytes ): reproducer( cur_module, + subgraph_info, self.thunder_options, example_input, reproducer_folder, diff --git a/thunder/dynamo/utils.py b/thunder/dynamo/utils.py index 4bab617cde..fda8f20c40 100644 --- a/thunder/dynamo/utils.py +++ b/thunder/dynamo/utils.py @@ -728,6 +728,7 @@ def thunder_options_to_str(thunder_options: dict) -> str: def reproducer( gm: torch.fx.GraphModule, + subgraph_info, thunder_options: dict, args: tuple[torch.Tensor | ExampleInputMetaData], folder: str | os.PathLike, @@ -742,6 +743,17 @@ def reproducer( readable = _readable(gm, "DynamoModule", print_output=False) has_cuda_args = any(hasattr(arg, "device") and arg.device.type == "cuda" for arg in args) thunder_options_str = thunder_options_to_str(thunder_options) + + # split reason + if subgraph_info.split_reasons: + num_submodules = len(subgraph_info.submodule_to_compiled_functions) + num_thunder_submodules = len(subgraph_info.thunder_compiled_fns) + split_reason_str = f"The original graph is split into {num_submodules} subgraphs, {num_thunder_submodules} of which are run by thunder.jit.\n" + split_reason_str += f"The structure of the split module:\n{subgraph_info.split_graph_module}\n" + split_reason_str += f"Split Reasons:\n" + for id, split_reason in enumerate(subgraph_info.split_reasons): + split_reason_str += f" Split Reason {id}:\n {split_reason.info}\n" + with open(folder / f"{graph_name}.py", "w") as f: comment_str = f'''""" Environment information get from `torch.utils.collect_env.get_pretty_env_info()`: @@ -750,10 +762,12 @@ def reproducer( Versions of Thunder related libraries: {thunder_pkgs} -The torch.fx.Graph: -{gm.graph} -""" ''' + if subgraph_info.split_reasons: + comment_str += f'{split_reason_str}"""\n' + del split_reason_str + else: + comment_str += '"""\n' if use_pytest_benchmark: comment_str += f"""# NOTE: This script requires `pytest-benchmark==4.0.0` to be installed. # To execute the script, run `pytest {graph_name}.py`""" diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py index cc740ff408..6873e5c4c2 100644 --- a/thunder/tests/test_dynamo.py +++ b/thunder/tests/test_dynamo.py @@ -864,3 +864,38 @@ def forward(self, x): cmd = "pytest" if use_pytest_benchmark else "python" result1 = run([cmd, s1], capture_output=True, text=True) assert result1.returncode == 0, f"Reproducer {s1} failed with return code {result1.returncode}" + + +@instantiate( + dtypes=NOTHING, + executors=[DynamoThunderExecutor], + decorators=(pytest.mark.parametrize("use_pytest_benchmark", (True, False), ids=("benchmark", "repro")),), +) +def test_dynamo_reproducer_split(executor, device: str, dtype: dtypes.dtype, use_pytest_benchmark, tmp_path): + x = torch.ones(2, 2, device=device, dtype=dtype, requires_grad=True) + + backend = ThunderCompiler() + + def func(x): + # torch.sinc has automatic fallback registered, + # so that operation will be given to inductor. + x = x.exp() + y = torch.sinc(x) + torch.cos(x) + y = y + torch.sinc(x) + return y + 1 + + cfunc = torch.compile(func, backend=backend) + actual = cfunc(x) + backend.save_reproducer_to_folder(tmp_path, use_pytest_benchmark) + + def check(file_name, cmd): + assert os.path.exists(file_name) + result = run([cmd, file_name], capture_output=True, text=True) + assert result.returncode == 0, f"Reproducer {file_name} failed with return code {result.returncode}" + + s1 = f"{tmp_path}/graph0_thunder_0.py" + s2 = f"{tmp_path}/graph0_thunder_2.py" + s3 = f"{tmp_path}/graph0_thunder_4.py" + cmd = "pytest" if use_pytest_benchmark else "python" + for fname in [s1, s2, s3]: + check(fname, cmd)