Skip to content

Commit

Permalink
ThunderFX: Add split reason and submodule structure in repro script (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
kiya00 authored Dec 4, 2024
1 parent 5f370f0 commit b243842
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 3 deletions.
1 change: 1 addition & 0 deletions thunder/dynamo/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def save_reproducer_to_folder(self, reproducer_folder: str | PathLike, use_pytes
):
reproducer(
cur_module,
subgraph_info,
self.thunder_options,
example_input,
reproducer_folder,
Expand Down
20 changes: 17 additions & 3 deletions thunder/dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,7 @@ def thunder_options_to_str(thunder_options: dict) -> str:

def reproducer(
gm: torch.fx.GraphModule,
subgraph_info: SubgraphInfo,
thunder_options: dict,
args: tuple[torch.Tensor | ExampleInputMetaData],
folder: str | os.PathLike,
Expand All @@ -742,6 +743,20 @@ def reproducer(
readable = _readable(gm, "DynamoModule", print_output=False)
has_cuda_args = any(hasattr(arg, "device") and arg.device.type == "cuda" for arg in args)
thunder_options_str = thunder_options_to_str(thunder_options)

# split reason
split_reason_str = "Split Information:\n"
if subgraph_info.split_reasons:
num_submodules = len(subgraph_info.submodule_to_compiled_functions)
num_thunder_submodules = len(subgraph_info.thunder_compiled_fns)
split_reason_str += f"The original graph is split into {num_submodules} subgraphs, {num_thunder_submodules} of which are run by Thunder.\n"
split_reason_str += f"The structure of the split module:\n{subgraph_info.split_graph_module}\n"
split_reason_str += f"Split Reasons:\n"
for id, split_reason in enumerate(subgraph_info.split_reasons):
split_reason_str += f" Split Reason {id}:\n {split_reason.info}\n"
else:
split_reason_str += "The original graph is not split, and is entirely run by Thunder.\n"

with open(folder / f"{graph_name}.py", "w") as f:
comment_str = f'''"""
Environment information get from `torch.utils.collect_env.get_pretty_env_info()`:
Expand All @@ -750,10 +765,9 @@ def reproducer(
Versions of Thunder related libraries:
{thunder_pkgs}
The torch.fx.Graph:
{gm.graph}
"""
'''
comment_str += f'{split_reason_str}"""\n'
del split_reason_str
if use_pytest_benchmark:
comment_str += f"""# NOTE: This script requires `pytest-benchmark==4.0.0` to be installed.
# To execute the script, run `pytest {graph_name}.py`"""
Expand Down
35 changes: 35 additions & 0 deletions thunder/tests/test_dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,3 +885,38 @@ def forward(self, x):

# No assertion error
copy_gm = copy.deepcopy(original_split_gm)


@instantiate(
dtypes=NOTHING,
executors=[DynamoThunderExecutor],
decorators=(pytest.mark.parametrize("use_pytest_benchmark", (True, False), ids=("benchmark", "repro")),),
)
def test_dynamo_reproducer_split(executor, device: str, dtype: dtypes.dtype, use_pytest_benchmark, tmp_path):
x = torch.ones(2, 2, device=device, dtype=dtype, requires_grad=True)

backend = ThunderCompiler()

def func(x):
# torch.sinc has automatic fallback registered,
# so that operation will be given to inductor.
x = x.exp()
y = torch.sinc(x) + torch.cos(x)
y = y + torch.sinc(x)
return y + 1

cfunc = torch.compile(func, backend=backend)
actual = cfunc(x)
backend.save_reproducer_to_folder(tmp_path, use_pytest_benchmark)

def check(file_name, cmd):
assert os.path.exists(file_name)
result = run([cmd, file_name], capture_output=True, text=True)
assert result.returncode == 0, f"Reproducer {file_name} failed with return code {result.returncode}"

s1 = f"{tmp_path}/graph0_thunder_0.py"
s2 = f"{tmp_path}/graph0_thunder_2.py"
s3 = f"{tmp_path}/graph0_thunder_4.py"
cmd = "pytest" if use_pytest_benchmark else "python"
for fname in [s1, s2, s3]:
check(fname, cmd)

0 comments on commit b243842

Please sign in to comment.