support graph-by-graph benchmarking for PyTorch native checkpointing #1437

Merged (4 commits) on Nov 19, 2024
18 changes: 18 additions & 0 deletions thunder/dynamo/compiler_graph_benchmark.py
@@ -2,6 +2,7 @@
from itertools import chain
from pytest_benchmark.fixture import BenchmarkFixture
from typing import TYPE_CHECKING
from looseversion import LooseVersion

import torch
from thunder.dynamo import ThunderCompiler
@@ -124,6 +125,23 @@ def run_bench(self, gm: torch.fx.GraphModule, name: str, *sample_args):

def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, torch.Tensor]):
split_module = super().__call__(gm, sample_args)

def has_checkpoint_node(g):
if g.find_nodes(op="call_function", target=torch.ops.higher_order.tag_activation_checkpoint):
return True
for n in g.nodes:
if n.op == "call_module" and has_checkpoint_node(getattr(g.owning_module, n.target).graph):
return True
return False

if LooseVersion(torch.__version__) < LooseVersion("2.6.0"):
# NOTE: PyTorch 2.6 changes the structure of GraphModule when using activation checkpointing.
# It's hard to retrieve the example input tensors for a GraphModule that contains a checkpoint operator before PyTorch 2.6
if has_checkpoint_node(split_module.graph):
raise RuntimeError(
"The benchmarking of the Torch activation checkpointing is only supported with PyTorch version 2.6 or later."
)

compiled_functions_to_submodule = {
v.compiled_fn: k for k, v in self.subgraph_infos[self.graph_idx].submodule_to_compiled_functions.items()
}
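A note on the detection logic above: `has_checkpoint_node` searches the split graph (and, recursively, any submodule reached through a `call_module` node) for the `tag_activation_checkpoint` higher-order operator. The sketch below is not part of the PR; `record_backend` and the toy `Model` are illustrative assumptions, and it assumes a PyTorch build in which non-reentrant `torch.utils.checkpoint.checkpoint` is traced into `torch.ops.higher_order.tag_activation_checkpoint`.

import torch
import torch.fx
import torch.nn as nn
import torch.utils.checkpoint

captured = []

def record_backend(gm: torch.fx.GraphModule, example_inputs):
    # Custom Dynamo backend that records the captured FX graph and runs it eagerly.
    captured.append(gm)
    return gm.forward

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(8, 8)

    def forward(self, x):
        # The checkpointed region is captured as a `tag_activation_checkpoint`
        # higher-order node in the Dynamo graph.
        x = torch.utils.checkpoint.checkpoint(self.layer, x, use_reentrant=False)
        return torch.relu(x)

compiled = torch.compile(Model(), backend=record_backend)
compiled(torch.randn(2, 8))

gm = captured[0]
has_ckpt = any(
    n.op == "call_function" and n.target is torch.ops.higher_order.tag_activation_checkpoint
    for n in gm.graph.nodes
)
print("checkpoint node found:", has_ckpt)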
14 changes: 11 additions & 3 deletions thunder/dynamo/splitter.py
@@ -1,5 +1,6 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import copy

import torch
from torch.fx.passes.split_module import split_module
@@ -131,9 +132,10 @@ def callback(node) -> int:
return partition_cnt

# `split_module` iterates over nodes and determines the partition to place them based on the callback.
split_gm: torch.fx.GraphModule = split_module(
original_split_gm: torch.fx.GraphModule = split_module(
gm, root_m=None, split_callback=callback, keep_original_order=True, keep_original_node_name=True
)
split_gm = copy.deepcopy(original_split_gm)

def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
return node.name.startswith("submod") and int(node.name.replace("submod_", "")) in supported_partitions
@@ -142,6 +144,7 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
thunder_compiled_fns = []
submodule_to_compiled_fns = {}
for node in split_gm.graph.nodes:
node_name = node.name
if is_thunder_supported_partition(node):
graph_module = getattr(split_gm, node.name)
# Replace PyTorch operators within the checkpointed function with the corresponding Thunder operators
@@ -150,13 +153,17 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
# Update the node name from "submod_*" to "thunder_*" for more user-friendly names
update_node_and_submodule(split_gm, node, node.name.replace("submod", "thunder"), jit_fn)
thunder_compiled_fns.append(jit_fn)
submodule_to_compiled_fns[graph_module] = CompiledFunction(jit_fn, CompilerType.THUNDER)
submodule_to_compiled_fns[getattr(original_split_gm, node_name)] = CompiledFunction(
jit_fn, CompilerType.THUNDER
)
elif node.name.startswith("submod"): # For inductor
graph_module = getattr(split_gm, node.name)
jit_fn = torch_inductor(graph_module)
# Update the node name from "submod_*" to "inductor_*" for more user-friendly names
update_node_and_submodule(split_gm, node, node.name.replace("submod", "inductor"), jit_fn)
submodule_to_compiled_fns[graph_module] = CompiledFunction(jit_fn, CompilerType.TORCH_INDUCTOR)
submodule_to_compiled_fns[getattr(original_split_gm, node_name)] = CompiledFunction(
jit_fn, CompilerType.TORCH_INDUCTOR
)
else:
# Everything else is a glue code to call and pass outputs between the other partitions.
pass
@@ -166,6 +173,7 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:

return split_gm, SubgraphInfo(
gm,
original_split_gm,
split_gm,
thunder_compiled_fns,
submodule_to_compiled_fns,
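The splitter change above deep-copies the result of `split_module` before any `submod_*` node is renamed or its submodule swapped for a compiled function, so `submodule_to_compiled_fns` can key on the untouched submodules of the original split module. A minimal standalone sketch of that pattern follows; the `Toy` module and the one-partition-per-node callback are made up for illustration, while the `split_module` keyword arguments mirror the ones used in the PR.

import copy
from itertools import count

import torch
import torch.fx
from torch.fx.passes.split_module import split_module

class Toy(torch.nn.Module):
    def forward(self, x):
        y = torch.sin(x)
        return torch.cos(y)

gm = torch.fx.symbolic_trace(Toy())

# Place every computation node in its own partition, standing in for the real
# supported/unsupported split decision made by the Thunder splitter.
partition_counter = count()

def callback(node) -> int:
    return next(partition_counter)

original_split_gm = split_module(
    gm, root_m=None, split_callback=callback, keep_original_order=True, keep_original_node_name=True
)
split_gm = copy.deepcopy(original_split_gm)

# Later passes rename "submod_*" nodes and replace submodules on `split_gm`;
# `original_split_gm` keeps the pristine submodules to use as stable dictionary keys.
print([n.name for n in split_gm.graph.nodes if n.op == "call_module"])
print([n.name for n in original_split_gm.graph.nodes if n.op == "call_module"])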
20 changes: 11 additions & 9 deletions thunder/dynamo/utils.py
@@ -80,17 +80,21 @@ class SubgraphInfo:

Attributes:
original_graph_module: The original graph module.
split_graph_module: The graph module for the split subgraph.
original_split_graph_module: The original split graph module before any transformations are applied.
Specifically, before the :func:`checkpoint_converter` replaces the Torch operators with Thunder symbols,
and before any submodules are compiled by Thunder.
split_graph_module: The graph module for the split subgraph. It contains the compiled thunder/inductor modules.
thunder_compiled_fns: List of thunder optimized callables.
This could be :obj:`None` if the graph module was not supported by thunder.
Look at the :attr:`split_reasons` for further information.
submodule_to_compiled_functions: Dict from subgraph to compiled function.
submodule_to_compiled_functions: Dict from subgraph in :attr:`original_split_graph_module` to compiled function.
This will be a dict with one pair in case the graph was not split.
split_reasons: List of reasons explaining why the subgraph was split.
Present only if there was a split.
"""

original_graph_module: torch.fx.GraphModule
original_split_graph_module: torch.fx.GraphModule | None
split_graph_module: torch.fx.GraphModule | None
thunder_compiled_fns: list[Callable] | None
submodule_to_compiled_functions: dict[torch.fx.GraphModule, CompiledFunction]
@@ -466,8 +470,7 @@ def _checkpoint_function_converter(gm: torch.fx.GraphModule):
Args:
gm (torch.fx.GraphModule): The GraphModule of the checkpointed function, which is modified inplace.
"""
new_graph = copy.deepcopy(gm.graph)
for n in new_graph.nodes:
for n in gm.graph.nodes:
# replace the torch operator in "call_function" node
if n.op == "call_function":
assert isinstance(n.target, Callable)
@@ -476,19 +479,18 @@ def _checkpoint_function_converter(gm: torch.fx.GraphModule):
check(
n.target in _torch_to_thunder_function_map, lambda: f"Unexpected {n.target}, not registered in Thunder"
)
with new_graph.inserting_before(n):
thunder_node = new_graph.call_function(
with gm.graph.inserting_before(n):
thunder_node = gm.graph.call_function(
_torch_to_thunder_function_map[n.target], args=n.args, kwargs=n.kwargs
)
n.replace_all_uses_with(thunder_node)
new_graph.erase_node(n)
gm.graph.erase_node(n)
else:
if n.op == "call_module":
raise RuntimeError(
"Unexpected call_module detected inside a checkpoint. This should have been inlined in dynamo graphs"
)
new_graph.lint()
gm.graph = new_graph
gm.graph.lint()
recompile_graph(gm)


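The utils.py change rewrites `_checkpoint_function_converter` to mutate `gm.graph` in place instead of deep-copying the graph and assigning it back. The sketch below is not taken from the PR; it applies the same insert-before / replace-all-uses / erase / lint / recompile pattern to a toy graph, and `double_add` and `REPLACEMENT_MAP` are hypothetical stand-ins for Thunder symbols and `_torch_to_thunder_function_map`.

import operator

import torch
import torch.fx

def double_add(a, b):
    # Stand-in replacement operator, analogous to a Thunder symbol in the real converter.
    return 2 * (a + b)

REPLACEMENT_MAP = {operator.add: double_add}

class Toy(torch.nn.Module):
    def forward(self, x, y):
        return x + y

gm = torch.fx.symbolic_trace(Toy())

for n in list(gm.graph.nodes):
    if n.op == "call_function" and n.target in REPLACEMENT_MAP:
        # Insert the replacement node right before the original one, reusing its args/kwargs.
        with gm.graph.inserting_before(n):
            new_node = gm.graph.call_function(REPLACEMENT_MAP[n.target], args=n.args, kwargs=n.kwargs)
        n.replace_all_uses_with(new_node)
        gm.graph.erase_node(n)

gm.graph.lint()
gm.recompile()
print(gm(torch.tensor(1.0), torch.tensor(2.0)))  # tensor(6.) with the doubled-add stand-in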
34 changes: 34 additions & 0 deletions thunder/tests/test_dynamo.py
@@ -5,6 +5,7 @@
import torch.fx
import torch.nn as nn
import torch.nn.functional as F
from looseversion import LooseVersion

from thunder import dtypes
from thunder.dynamo import ThunderCompiler
@@ -445,6 +446,10 @@ def func(x):
IS_WINDOWS,
reason="torch.compile Windows support is still WIP - https://github.com/pytorch/pytorch/issues/122094",
),
pytest.mark.skipif(
LooseVersion(torch.__version__) < LooseVersion("2.6.0"),
reason="Skip until the Torch bug is fixed - https://github.com/pytorch/pytorch/pull/139275",
),
),
)
@requiresCUDA
@@ -639,6 +644,35 @@ def f(x):
compiled(x)


@pytest.mark.skipif(
LooseVersion(torch.__version__) < LooseVersion("2.6.0"),
reason="The checkpoint function becomes a submodule of the module containing `tag_activation_checkpoint` in PyTorch 2.6.0.",
)
@requiresCUDA
def test_ThunderCompilerGraphBenchmarking_checkpoint(benchmark):
class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.layer1 = nn.Linear(10, 20)

def forward(self, x):
x = torch.utils.checkpoint.checkpoint(self.layer1, x)
x = F.relu(x)
return x

x = torch.randn(5, 10).cuda().requires_grad_()
model = SimpleModel().cuda().train()

exe_backend = ThunderCompiler()
backend = ThunderCompilerGraphBenchmarking(
benchmark, executors={"inductor": torch.compile, "thunderfx": torch.compile(backend=exe_backend)}
)
# Using torch.compile here fails with "TypeError: cannot pickle '_io.TextIOWrapper' object" in
# https://github.com/Lightning-AI/pytorch-lightning/blob/828fd998961f6a60f92c35254bb94d6e049ad069/src/lightning/fabric/wrappers.py#L421
jf = torch._dynamo.optimize(backend=backend)(model)
out = jf(x)


@requiresCUDA
@pytest.mark.filterwarnings(r"ignore:`torch\.cpu\.amp\.autocast\((.*?)\)` is deprecated.*:FutureWarning")
def test_checkpoint_converter():
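Regarding the comment in the benchmarking test about `torch.compile` failing with a pickling error: the sketch below is illustrative and not part of the PR. It shows the two ways of attaching a custom Dynamo backend using a trivial pass-through `my_backend`; the test uses the `torch._dynamo.optimize(backend=...)` form, which avoids the `TypeError: cannot pickle '_io.TextIOWrapper' object` reportedly hit when `torch.compile` is applied through the linked Lightning Fabric wrapper.

import torch
import torch.fx

def my_backend(gm: torch.fx.GraphModule, example_inputs):
    # Trivial pass-through backend: run the captured graph eagerly.
    return gm.forward

model = torch.nn.Linear(4, 4)
x = torch.randn(2, 4)

compiled_a = torch.compile(model, backend=my_backend)            # the usual entry point
compiled_b = torch._dynamo.optimize(backend=my_backend)(model)   # the form used in the test

print(torch.allclose(compiled_a(x), compiled_b(x)))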