Commit 60f3ee1

support graph-by-graph benchmarking for PyTorch native checkpointing (#1437)
1 parent f206afa commit 60f3ee1

4 files changed: 74 additions, 12 deletions

thunder/dynamo/compiler_graph_benchmark.py

Lines changed: 18 additions & 0 deletions
@@ -2,6 +2,7 @@
 from itertools import chain
 from pytest_benchmark.fixture import BenchmarkFixture
 from typing import TYPE_CHECKING
+from looseversion import LooseVersion
 
 import torch
 from thunder.dynamo import ThunderCompiler
@@ -124,6 +125,23 @@ def run_bench(self, gm: torch.fx.GraphModule, name: str, *sample_args):
 
     def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, torch.Tensor]):
         split_module = super().__call__(gm, sample_args)
+
+        def has_checkpoint_node(g):
+            if g.find_nodes(op="call_function", target=torch.ops.higher_order.tag_activation_checkpoint):
+                return True
+            for n in g.nodes:
+                if n.op == "call_module" and has_checkpoint_node(getattr(g.owning_module, n.target).graph):
+                    return True
+            return False
+
+        if LooseVersion(torch.__version__) < LooseVersion("2.6.0"):
+            # NOTE: PyTorch 2.6 changes the structure of GraphModule when using activation checkpointing.
+            # Before PyTorch 2.6 it is hard to retrieve the example input tensors for a GraphModule that contains the checkpoint operator.
+            if has_checkpoint_node(split_module.graph):
+                raise RuntimeError(
+                    "Benchmarking Torch activation checkpointing is only supported with PyTorch 2.6 or later."
+                )
+
         compiled_functions_to_submodule = {
             v.compiled_fn: k for k, v in self.subgraph_infos[self.graph_idx].submodule_to_compiled_functions.items()
         }
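
For orientation, here is a minimal standalone sketch (not part of the commit) of what the new has_checkpoint_node check looks for: under torch.compile, an activation-checkpointed region is captured as the higher-order tag_activation_checkpoint operator, which a backend can find via Graph.find_nodes. The Tiny module, inspecting_backend, and contains_checkpoint names below are illustrative only, and the sketch assumes a recent PyTorch that provides Graph.find_nodes.

import torch
import torch.fx
import torch.nn as nn
import torch.utils.checkpoint


def contains_checkpoint(graph: torch.fx.Graph) -> bool:
    # Same idea as has_checkpoint_node above: look for the higher-order
    # tag_activation_checkpoint op, then recurse into called submodules.
    if graph.find_nodes(op="call_function", target=torch.ops.higher_order.tag_activation_checkpoint):
        return True
    return any(
        n.op == "call_module" and contains_checkpoint(getattr(graph.owning_module, n.target).graph)
        for n in graph.nodes
    )


class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4)

    def forward(self, x):
        return torch.utils.checkpoint.checkpoint(self.layer, x, use_reentrant=False)


def inspecting_backend(gm: torch.fx.GraphModule, example_inputs):
    # A trivial Dynamo backend that only reports whether the captured graph
    # contains an activation-checkpoint region, then runs it eagerly.
    print("checkpoint op found:", contains_checkpoint(gm.graph))
    return gm.forward


model = Tiny()
compiled = torch.compile(model, backend=inspecting_backend)
compiled(torch.randn(2, 4, requires_grad=True))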

thunder/dynamo/splitter.py

Lines changed: 11 additions & 3 deletions
@@ -1,5 +1,6 @@
 from __future__ import annotations
 from typing import TYPE_CHECKING
+import copy
 
 import torch
 from torch.fx.passes.split_module import split_module
@@ -131,9 +132,10 @@ def callback(node) -> int:
         return partition_cnt
 
     # `split_module` iterates over nodes and determines the partition to place them based on the callback.
-    split_gm: torch.fx.GraphModule = split_module(
+    original_split_gm: torch.fx.GraphModule = split_module(
         gm, root_m=None, split_callback=callback, keep_original_order=True, keep_original_node_name=True
     )
+    split_gm = copy.deepcopy(original_split_gm)
 
     def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
         return node.name.startswith("submod") and int(node.name.replace("submod_", "")) in supported_partitions
@@ -142,6 +144,7 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
     thunder_compiled_fns = []
     submodule_to_compiled_fns = {}
     for node in split_gm.graph.nodes:
+        node_name = node.name
         if is_thunder_supported_partition(node):
             graph_module = getattr(split_gm, node.name)
             # Replace PyTorch operators within the checkpointed function with the corresponding Thunder operators
@@ -150,13 +153,17 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
             # Update the node name from "submod_*" to "thunder_*" for more user-friendly names
             update_node_and_submodule(split_gm, node, node.name.replace("submod", "thunder"), jit_fn)
             thunder_compiled_fns.append(jit_fn)
-            submodule_to_compiled_fns[graph_module] = CompiledFunction(jit_fn, CompilerType.THUNDER)
+            submodule_to_compiled_fns[getattr(original_split_gm, node_name)] = CompiledFunction(
+                jit_fn, CompilerType.THUNDER
+            )
         elif node.name.startswith("submod"):  # For inductor
             graph_module = getattr(split_gm, node.name)
             jit_fn = torch_inductor(graph_module)
             # Update the node name from "submod_*" to "inductor_*" for more user-friendly names
             update_node_and_submodule(split_gm, node, node.name.replace("submod", "inductor"), jit_fn)
-            submodule_to_compiled_fns[graph_module] = CompiledFunction(jit_fn, CompilerType.TORCH_INDUCTOR)
+            submodule_to_compiled_fns[getattr(original_split_gm, node_name)] = CompiledFunction(
+                jit_fn, CompilerType.TORCH_INDUCTOR
+            )
         else:
             # Everything else is a glue code to call and pass outputs between the other partitions.
             pass
@@ -166,6 +173,7 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
 
     return split_gm, SubgraphInfo(
         gm,
+        original_split_gm,
         split_gm,
         thunder_compiled_fns,
         submodule_to_compiled_fns,
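
The deepcopy above exists because the working split_gm is mutated afterwards (its submodules are renamed and replaced by compiled callables), while original_split_gm keeps the pristine submodules that now serve as the keys of submodule_to_compiled_fns. A minimal sketch of that pattern, with an illustrative M module and callback that are not from the repo:

import copy

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.split_module import split_module


class M(torch.nn.Module):
    def forward(self, x):
        return torch.sin(torch.relu(x))


gm = symbolic_trace(M())


# Put relu and sin into different partitions so split_module creates two submodules.
def callback(node: torch.fx.Node) -> int:
    return 0 if node.target is torch.relu else 1


original_split_gm = split_module(gm, root_m=None, split_callback=callback)
split_gm = copy.deepcopy(original_split_gm)

# Mutating the working copy (as the splitter does when it renames submodules and
# swaps in compiled callables) leaves the deepcopied originals untouched.
split_gm.submod_0 = torch.nn.Identity()
print(type(split_gm.submod_0).__name__)            # Identity
print(type(original_split_gm.submod_0).__name__)   # GraphModule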

thunder/dynamo/utils.py

Lines changed: 11 additions & 9 deletions
@@ -80,17 +80,21 @@ class SubgraphInfo:
 
     Attributes:
         original_graph_module: The original graph module.
-        split_graph_module: The graph module for the split subgraph.
+        original_split_graph_module: The original split graph module before any transformations are applied.
+            Specifically, before the :func:`checkpoint_converter` replaces the Torch operators with Thunder symbols,
+            and before any submodules are compiled by Thunder.
+        split_graph_module: The graph module for the split subgraph. It contains the compiled thunder/inductor modules.
         thunder_compiled_fns: List of thunder optimized callables.
             This could be :obj:`None` if there the graph module was not supported by thunder.
             Look at the :attr:`split_reasons` for further information.
-        submodule_to_compiled_functions: Dict from subgraph to compiled function.
+        submodule_to_compiled_functions: Dict from subgraph in :attr:`original_split_graph_module` to compiled function.
            This will be a dict with one pair in case the graph was not split.
        split_reasons: List of reasons explaining why the subgraph was split.
            Present only if there are was a split.
     """
 
     original_graph_module: torch.fx.GraphModule
+    original_split_graph_module: torch.fx.GraphModule | None
     split_graph_module: torch.fx.GraphModule | None
     thunder_compiled_fns: list[Callable] | None
     submodule_to_compiled_functions: dict[torch.fx.GraphModule, CompiledFunction]
@@ -466,8 +470,7 @@ def _checkpoint_function_converter(gm: torch.fx.GraphModule):
     Args:
         gm (torch.fx.GraphModule): The GraphModule of the checkpointed function, which is modified inplace.
     """
-    new_graph = copy.deepcopy(gm.graph)
-    for n in new_graph.nodes:
+    for n in gm.graph.nodes:
         # replace the torch operator in "call_function" node
         if n.op == "call_function":
             assert isinstance(n.target, Callable)
@@ -476,19 +479,18 @@ def _checkpoint_function_converter(gm: torch.fx.GraphModule):
             check(
                 n.target in _torch_to_thunder_function_map, lambda: f"Unexpected {n.target}, not registered in Thunder"
             )
-            with new_graph.inserting_before(n):
-                thunder_node = new_graph.call_function(
+            with gm.graph.inserting_before(n):
+                thunder_node = gm.graph.call_function(
                     _torch_to_thunder_function_map[n.target], args=n.args, kwargs=n.kwargs
                 )
             n.replace_all_uses_with(thunder_node)
-            new_graph.erase_node(n)
+            gm.graph.erase_node(n)
         else:
             if n.op == "call_module":
                 raise RuntimeError(
                     "Unexpected call_module detected inside a checkpoint. This should have been inlined in dynamo graphs"
                 )
-    new_graph.lint()
-    gm.graph = new_graph
+    gm.graph.lint()
     recompile_graph(gm)
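
_checkpoint_function_converter now edits gm.graph in place rather than deep-copying it first. Below is a standalone sketch of the same FX surgery pattern (an illustrative fn with a torch.sigmoid to torch.tanh swap; gm.recompile() stands in for the repo's recompile_graph helper):

import torch
from torch.fx import symbolic_trace


def fn(x):
    return torch.sigmoid(x) + 1


gm = symbolic_trace(fn)

# Rewrite gm.graph directly: insert the replacement node, redirect uses,
# erase the old node, then lint and recompile.
for n in list(gm.graph.nodes):
    if n.op == "call_function" and n.target is torch.sigmoid:
        with gm.graph.inserting_before(n):
            new_node = gm.graph.call_function(torch.tanh, args=n.args, kwargs=n.kwargs)
        n.replace_all_uses_with(new_node)
        gm.graph.erase_node(n)

gm.graph.lint()
gm.recompile()

x = torch.randn(3)
assert torch.allclose(gm(x), torch.tanh(x) + 1)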
thunder/tests/test_dynamo.py

Lines changed: 34 additions & 0 deletions
@@ -5,6 +5,7 @@
 import torch.fx
 import torch.nn as nn
 import torch.nn.functional as F
+from looseversion import LooseVersion
 
 from thunder import dtypes
 from thunder.dynamo import ThunderCompiler
@@ -445,6 +446,10 @@ def func(x):
             IS_WINDOWS,
             reason="torch.compile Windows support is still WIP - https://github.com/pytorch/pytorch/issues/122094",
         ),
+        pytest.mark.skipif(
+            LooseVersion(torch.__version__) < LooseVersion("2.6.0"),
+            reason="Skip until the Torch bug is fixed - https://github.com/pytorch/pytorch/pull/139275",
+        ),
     ),
 )
 @requiresCUDA
@@ -639,6 +644,35 @@ def f(x):
     compiled(x)
 
 
+@pytest.mark.skipif(
+    LooseVersion(torch.__version__) < LooseVersion("2.6.0"),
+    reason="The checkpoint function becomes a submodule of the module containing `tag_activation_checkpoint` in PyTorch 2.6.0.",
+)
+@requiresCUDA
+def test_ThunderCompilerGraphBenchmarking_checkpoint(benchmark):
+    class SimpleModel(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.layer1 = nn.Linear(10, 20)
+
+        def forward(self, x):
+            x = torch.utils.checkpoint.checkpoint(self.layer1, x)
+            x = F.relu(x)
+            return x
+
+    x = torch.randn(5, 10).cuda().requires_grad_()
+    model = SimpleModel().cuda().train()
+
+    exe_backend = ThunderCompiler()
+    backend = ThunderCompilerGraphBenchmarking(
+        benchmark, executors={"inductor": torch.compile, "thunderfx": torch.compile(backend=exe_backend)}
+    )
+    # Using torch.compile here fails with "TypeError: cannot pickle '_io.TextIOWrapper' object" in
+    # https://github.com/Lightning-AI/pytorch-lightning/blob/828fd998961f6a60f92c35254bb94d6e049ad069/src/lightning/fabric/wrappers.py#L421
+    jf = torch._dynamo.optimize(backend=backend)(model)
+    out = jf(x)
+
+
 @requiresCUDA
 @pytest.mark.filterwarnings(r"ignore:`torch\.cpu\.amp\.autocast\((.*?)\)` is deprecated.*:FutureWarning")
 def test_checkpoint_converter():
