Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ThunderFX: Save the reproducer script into files #1380

Merged
merged 28 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
8628bf4
copy Tom's code and use it in ThunderCompiler (#1082)
kiya00 Oct 23, 2024
de955c4
Add env information
kiya00 Oct 24, 2024
5e090b1
record the input information in SubgraphInfo and save reproducer afte…
kiya00 Oct 28, 2024
58c1b70
add comments
kiya00 Oct 29, 2024
d9a07a8
fix test
kiya00 Oct 31, 2024
d94ad53
fix bug: _get_example_inputs_from_placeholder
kiya00 Oct 31, 2024
414c270
Add comments
kiya00 Oct 31, 2024
c0a7b7c
support torch.types.py_sym_types
kiya00 Nov 11, 2024
d685fde
Add torch.fx.Graph in comments
kiya00 Nov 11, 2024
ffa5d22
fix rebase conflict
kiya00 Nov 11, 2024
df22f08
fix test
kiya00 Nov 12, 2024
a4f4567
Use the original thunder options as default
kiya00 Nov 14, 2024
4abc092
fix test: no example_value in node.meta
kiya00 Nov 14, 2024
43f8b16
use torch.testing.make_tensor
kiya00 Nov 14, 2024
4423dce
fix
kiya00 Nov 14, 2024
8ac866e
fix test
kiya00 Nov 14, 2024
67f65cf
use torch.full instead of make_tensor when low==high, add save_dynamo…
kiya00 Nov 15, 2024
a257896
fix: symInt in stride
kiya00 Nov 15, 2024
db6f040
fix
kiya00 Nov 15, 2024
d807902
modify repro script
kiya00 Nov 19, 2024
34c0c43
updata test
kiya00 Nov 19, 2024
e0e0fe7
fix: collect peak memory only when cuda is available
kiya00 Nov 19, 2024
4554dcf
use the original GraphModule
kiya00 Nov 19, 2024
d457c19
fix: aminmax throws error if input is empty tensor
kiya00 Nov 19, 2024
83c231e
follow comments
kiya00 Nov 20, 2024
e55623d
follow comments
kiya00 Nov 26, 2024
d302b63
Merge branch 'main' into dump_reproducer
kiya00 Nov 26, 2024
1752509
fix tests, splitter indexing from 0
kiya00 Nov 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions thunder/core/transform_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,9 @@ def reverse_transform_state_dict_for_submodule(
) -> dict[str, Any]:
return state_dict

def __repr__(self) -> str:
return f"{self.__class__.__module__}.{self.__class__.__name__}()"

IvanYashchuk marked this conversation as resolved.
Show resolved Hide resolved

def order_proxies(bsyms: Sequence[BoundSymbol]) -> dict[str, int]:
"""computes a canonical ordering of proxies in the bound symbols based on the order of appearance
Expand Down
46 changes: 45 additions & 1 deletion thunder/dynamo/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
import torch

from thunder.core.baseutils import run_once
from thunder.dynamo.utils import recompile_graph, remove_empty_autocast
from thunder.core.utils import safe_zip
from thunder.dynamo.utils import recompile_graph, remove_empty_autocast, reproducer, CompilerType
from thunder.dynamo.splitter import _splitter

if TYPE_CHECKING:
from thunder.dynamo.utils import SubgraphInfo
from os import PathLike


@run_once
Expand Down Expand Up @@ -83,3 +85,45 @@ def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, tor
split_module, subgraph_info = _splitter(gm, self._thunder_jit, self._torch_compile, sample_args)
self.subgraph_infos.append(subgraph_info)
return split_module

def save_reproducer_to_folder(self, reproducer_folder: str | PathLike, use_pytest_benchmark: bool = False):
"""
Save the reproducer script for the GraphModule executed by Thunder to the specified `reproducer_folder`.
Each saved script is named as "graph[graph_id]_thunder_[module_id]", where:

- `graph_id` indexes the graph generated by Dynamo, which is then passed to Thunder.
- `module_id` indexes the submodule split by the :func:`thunder.dynamo.utils._splitter`.

Args:
reproducer_folder (str | PathLike): The folder where the reproducer code will be written. Can be specified as an absolute or relative path.
use_pytest_benchmark (str): Determines the type of script to create:

- If use_pytest_benchmark=False: Creates a reproducer script.
- If use_pytest_benchmark=True: Creates a benchmark script to compare the reproducer's performance with other backends, including Torch eager, torch.compile, and torch.compile with `backend="eager"`.
"""
if not self.subgraph_infos:
raise TypeError(f"{self} doesn't seem to have been called yet.")

for graph_idx, subgraph_info in enumerate(self.subgraph_infos):
thunder_module_names = []
for node in subgraph_info.split_graph_module.graph.nodes:
target = node.target
if isinstance(target, str) and target.startswith("thunder_"):
thunder_module_names.append(target)
original_thunder_modules = (
m
for m, compiled_m in subgraph_info.submodule_to_compiled_functions.items()
if compiled_m.compiler == CompilerType.THUNDER
)
example_inputs = subgraph_info.thunder_compiled_fns_example_inputs
for cur_module, example_input, cur_name in safe_zip(
original_thunder_modules, example_inputs, thunder_module_names
):
reproducer(
cur_module,
self.thunder_options,
example_input,
reproducer_folder,
f"graph{graph_idx}_{cur_name}",
use_pytest_benchmark,
)
26 changes: 16 additions & 10 deletions thunder/dynamo/compiler_graph_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,19 +103,25 @@ def run_bench(self, gm: torch.fx.GraphModule, name: str, *sample_args):
if self.post_graph:
compiled_fn = self.post_graph(compiled_fn, sample_args)

with record_peak_allocated_memory(self.bench):
# This guard ensures compatibility with CPU-only PyTorch builds.
if torch.cuda.is_available():
kiya00 marked this conversation as resolved.
Show resolved Hide resolved
with record_peak_allocated_memory(self.bench):
self.bench(compiled_fn, *sample_args)
else:
self.bench(compiled_fn, *sample_args)
# BenchmarkFixture.stats is created each time bench is called (ref: https://github.com/pybenchmark/pytest-benchmark/blob/8c9a5faa1dd178b53ab7b2a66f5364a77e903d74/src/pytest_benchmark/fixture.py#L150)
# Adds the graph number, split module name and executor suffix to the name string
gid_key, module_name_key, ex_key = GRAPH_BY_GRAPH_BENCHMARK_PARAMS_KEYS
self.bench.stats.name += f"-{gid_key}[{self.graph_idx+1}]-{module_name_key}[{name}]-{ex_key}[{ex_name}]"
assert MAX_ALLOCATED_MEMORY_KEYWORD in self.bench.extra_info
assert f"{self.bench.stats.name}_{MAX_ALLOCATED_MEMORY_KEYWORD}" not in self.bench.extra_info
# NOTE: A benchmark can include multiple stats, but only one extra_info field is allowed per benchmark.
# Therefore, we use the current stats name as a prefix to distinguish memory usage for each stats.
self.bench.extra_info[f"{self.bench.stats.name}_{MAX_ALLOCATED_MEMORY_KEYWORD}"] = (
self.bench.extra_info.pop(MAX_ALLOCATED_MEMORY_KEYWORD)
)
self.bench.stats.name += f"-{gid_key}[{self.graph_idx}]-{module_name_key}[{name}]-{ex_key}[{ex_name}]"

if torch.cuda.is_available():
assert MAX_ALLOCATED_MEMORY_KEYWORD in self.bench.extra_info
assert f"{self.bench.stats.name}_{MAX_ALLOCATED_MEMORY_KEYWORD}" not in self.bench.extra_info
# NOTE: A benchmark can include multiple stats, but only one extra_info field is allowed per benchmark.
# Therefore, we use the current stats name as a prefix to distinguish memory usage for each stats.
self.bench.extra_info[f"{self.bench.stats.name}_{MAX_ALLOCATED_MEMORY_KEYWORD}"] = (
self.bench.extra_info.pop(MAX_ALLOCATED_MEMORY_KEYWORD)
)

# when the graph is segmented, the self.bench run multiple times, pybenchmark throws an error:
# `FixtureAlreadyUsed("Fixture can only be used once. Previously it was used in %s mode." % self._mode)`
Expand Down Expand Up @@ -158,7 +164,7 @@ def has_checkpoint_node(g):
cur_nodes = cur_module.graph.nodes
# Greates random input values for the current module based on the faketensor 'example_value' of the placeholder node
placeholders = list(n for n in cur_nodes if n.op == "placeholder")
args = chain(*map(_get_example_inputs_from_placeholder, placeholders))
args = list(map(_get_example_inputs_from_placeholder, placeholders))
IvanYashchuk marked this conversation as resolved.
Show resolved Hide resolved
# Runs the benchmark on the original module with the generated random inputs
self.run_bench(compiled_functions_to_submodule[cur_module], target, *args)
self.graph_idx += 1
Expand Down
13 changes: 12 additions & 1 deletion thunder/dynamo/splitter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import copy
from functools import partial

import torch
from torch.fx.passes.split_module import split_module
Expand All @@ -16,6 +17,7 @@
update_node_and_submodule,
recompile_graph,
checkpoint_converter,
_get_example_inputs_from_placeholder,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -124,8 +126,9 @@ def callback(node) -> int:
return partition_cnt

# There is a flip. Either from supported to unsupported or unsupported to supported.
if prev_value is not None:
partition_cnt += 1 # Bump the region cnt.
prev_value = is_thunder_supported
partition_cnt += 1 # Bump the region cnt.

if is_thunder_supported:
supported_partitions.add(partition_cnt)
Expand All @@ -142,11 +145,18 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:

# Call compile on the split region/s.
thunder_compiled_fns = []
example_input_metadatas = []
submodule_to_compiled_fns = {}
for node in split_gm.graph.nodes:
node_name = node.name
if is_thunder_supported_partition(node):
graph_module = getattr(split_gm, node.name)
# Record the input tensor metadata of the current module based on the faketensor 'example_value' of the placeholder node
kiya00 marked this conversation as resolved.
Show resolved Hide resolved
placeholders = list(n for n in graph_module.graph.nodes if n.op == "placeholder")
example_input_metadata = map(
partial(_get_example_inputs_from_placeholder, only_metadata=True), placeholders
)
example_input_metadatas.append(list(example_input_metadata))
# Replace PyTorch operators within the checkpointed function with the corresponding Thunder operators
checkpoint_converter(split_gm, graph_module)
jit_fn = thunder_jit(graph_module)
Expand Down Expand Up @@ -176,6 +186,7 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
original_split_gm,
split_gm,
thunder_compiled_fns,
example_input_metadatas,
submodule_to_compiled_fns,
split_reasons,
)
Loading
Loading