A converter for FXGraph with Torch calls -> FXGraph with Thunder calls #1261

Merged Nov 4, 2024 (25 commits; the diff below shows changes from 19 commits)

Commits:
e0d394d - Add fxgraph with torch function to fxgraph with thunder symbol converter (kiya00, Oct 4, 2024)
b554a4f - fix: torch.ops.higher_order.tag_activation_checkpoint cause the graph… (kiya00, Oct 4, 2024)
b08fe09 - add test (kiya00, Oct 4, 2024)
88c3669 - fix (kiya00, Oct 7, 2024)
30af37c - move the converter into dynamo utils (kiya00, Oct 9, 2024)
93d4492 - make the converter work with splitter (kiya00, Oct 11, 2024)
a6833ca - Keep the original split_module since the converter changes the checkp… (kiya00, Oct 11, 2024)
84df6b6 - mv test to test_dynamo.py (kiya00, Oct 11, 2024)
184abfe - clean up (kiya00, Oct 11, 2024)
84257dc - add test (kiya00, Oct 11, 2024)
6d8d8a1 - follow comments (kiya00, Oct 14, 2024)
6305828 - rm the original_split_module (kiya00, Oct 15, 2024)
d0347b8 - rm original split module graph (kiya00, Oct 15, 2024)
951dd8b - [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 15, 2024)
8cc83e6 - ignore futurewarning (kiya00, Oct 17, 2024)
ecac46f - Apply suggestions from code review (kiya00, Oct 18, 2024)
e826faa - use pytest.mark.filterwarnings (kiya00, Oct 18, 2024)
4ff8580 - fix rebase functional-autograd-checkpoint (kiya00, Oct 18, 2024)
856141c - follow comments (kiya00, Oct 18, 2024)
5408478 - Apply suggestions from code review (kiya00, Oct 29, 2024)
4f453cb - Merge branch 'main' into basedon-functional-autograd-checkpoint (kiya00, Oct 29, 2024)
f417269 - fix CI: torch nightly has changed the structure of split_gm (kiya00, Oct 29, 2024)
f25cd0d - Merge branch 'main' into basedon-functional-autograd-checkpoint (kiya00, Oct 30, 2024)
1371079 - Merge branch 'main' into basedon-functional-autograd-checkpoint (kiya00, Oct 31, 2024)
2489292 - Merge branch 'main' into basedon-functional-autograd-checkpoint (IvanYashchuk, Nov 4, 2024)
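For context, here is a minimal end-to-end sketch of the workflow this PR enables, mirroring the tests added to thunder/tests/test_dynamo.py further below. The module shape, the names, and the use_reentrant flag are illustrative choices, and a CUDA device is assumed as in the tests.

# Illustrative sketch, not part of the diff: compile a checkpointed model with the ThunderCompiler backend.
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

from thunder.dynamo import ThunderCompiler

class CheckpointedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 10)

    def forward(self, x):
        # The checkpointed call shows up as a torch.ops.higher_order.tag_activation_checkpoint
        # node in the dynamo FX graph; this PR converts the wrapped submodule to Thunder symbols
        # before the partition is passed to thunder_jit.
        return checkpoint.checkpoint(self.layer, x, use_reentrant=False)

backend = ThunderCompiler()
model = CheckpointedModel().cuda()
compiled = torch.compile(model, backend=backend)
out = compiled(torch.randn(5, 10, device="cuda"))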
3 changes: 3 additions & 0 deletions thunder/dynamo/splitter.py
@@ -14,6 +14,7 @@
get_nodes_in_unsupported_ctx_regions,
update_node_and_submodule,
recompile_graph,
checkpoint_converter,
)

if TYPE_CHECKING:
@@ -143,6 +144,8 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
for node in split_gm.graph.nodes:
if is_thunder_supported_partition(node):
graph_module = getattr(split_gm, node.name)
# Replace the torch operators within the function called by activation checkpoint with the corresponding Thunder symbols
checkpoint_converter(split_gm, graph_module)
jit_fn = thunder_jit(graph_module)
# Update the node name from "submod_*" to "thunder_*" for more user-friendly names
update_node_and_submodule(split_gm, node, node.name.replace("submod", "thunder"), jit_fn)
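As a reading aid for the splitter change above, the split module produced by the backend can be inspected after one call of the compiled model, following the pattern used in test_checkpoint_converter_submodule further below. The backend object is assumed to be a ThunderCompiler instance that has already been used with torch.compile.

# Assumes `backend = ThunderCompiler()` was passed to torch.compile and the compiled model was called once.
subgraph_info = backend.subgraph_infos[0]
split_gm = subgraph_info.split_graph_module

# Thunder-supported partitions are renamed from "submod_*" to "thunder_*" by update_node_and_submodule,
# and the body of a checkpointed region lives in a "wrap_body_*" submodule of the split module.
for name, submodule in split_gm.named_children():
    print(name, type(submodule))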
84 changes: 83 additions & 1 deletion thunder/dynamo/utils.py
@@ -5,14 +5,15 @@
import dataclasses
import inspect
import itertools
import warnings
import copy

import torch

from thunder.torch.default_torch_ops import torch_auto_registered_ops
from thunder.torch import _torch_to_thunder_function_map
from thunder.torch.langctx import torchctx
from thunder.core.utils import check
from thunder.core.pytree import tree_map

if TYPE_CHECKING:
from thunder.core.symbol import Symbol
@@ -259,6 +260,27 @@ def is_no_grad_ctx_exit(node):
return nodes_in_unsupported_ctx_regions


def is_graphmodule_supported_by_thunder(gm):
nodes_in_unsupported_ctx_regions = get_nodes_in_unsupported_ctx_regions(gm)
for node in gm.graph.nodes:
if node.op in (
"placeholder",
"get_attr",
"output",
):
continue
if node in nodes_in_unsupported_ctx_regions:
split_reason = SplitReason(
SplitReasonType.UNSUPPORTED_NODE,
info=f"node with name: {node.name} and target: {node.target} is not supported probably because it is in unsupported context.",
)
return False, split_reason
is_thunder_supported, split_reason = is_node_supported_by_thunder(node)
if not is_thunder_supported:
return False, split_reason
return True, None


def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason | None]:
"""
Determine whether thunder can execute the operation described by this node.
@@ -306,6 +328,14 @@ def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason | None]:
)
return False, split_reason

# If the operation is the higher-order operator used for activation checkpointing, check whether the checkpointed submodule is supported by Thunder
if target is torch.ops.higher_order.tag_activation_checkpoint:
m = node.graph.owning_module
assert hasattr(m, node.args[0].name)
checkpointed_fn = getattr(m, node.args[0].name)
is_module_supported, split_reason = is_graphmodule_supported_by_thunder(checkpointed_fn)
return is_module_supported, split_reason

# If thunder has a mapping for this operation, try executing the meta function and see.
# We have a symbol for `torch.where`, but we don't support one overload of it.
# So, we try and execute the meta to get a real signal.
@@ -418,3 +448,55 @@ def _get_example_inputs_from_placeholder(node) -> tuple[torch.Tensor]:
raise TypeError(
"The 'example_value' in the placeholder node is expected to be either a Tensor or a Tuple of Tensors."
)


def _checkpoint_function_converter(gm: torch.fx.GraphModule):
"""
Replace, in place, the Torch operators in the GraphModule called by the activation checkpoint operator with the corresponding Thunder symbols
Args:
gm: The GraphModule of the checkpointed function, which is modified in place
"""
new_graph = copy.deepcopy(gm.graph)
for n in new_graph.nodes:
# replace the torch operator in "call_function" node
if n.op == "call_function":
assert isinstance(n.target, Callable)
if n.target.__module__ in ("_operator", "builtins"):
continue
check(
n.target in _torch_to_thunder_function_map, lambda: f"Unexpected {n.target}, not registered in Thunder"
)
with new_graph.inserting_before(n):
thunder_node = new_graph.call_function(
_torch_to_thunder_function_map[n.target], args=n.args, kwargs=n.kwargs
)
n.replace_all_uses_with(thunder_node)
new_graph.erase_node(n)
else:
if n.op == "call_module":
raise RuntimeError(
"Unexpected call_module detected inside a checkpoint. This should have been inlined in dynamo graphs"
)
new_graph.lint()
gm.graph = new_graph
recompile_graph(gm)


def checkpoint_converter(gm: torch.fx.GraphModule, sub_gm: torch.fx.GraphModule):
"""
Utility function to convert the GraphModule that uses activation checkpointing into a Thunder-traceable GraphModule.

Args:
gm: The parent GraphModule containing the submodule (sub_gm), as well as the GraphModule of the checkpointed function.
sub_gm: the GraphModule containing the checkpoint operator

Note:
The GraphModule of the checkpointed function is updated in place
"""
for n in sub_gm.graph.nodes:
if n.op == "call_function":
if n.target in (torch.ops.higher_order.tag_activation_checkpoint,):
name = n.args[0].name
assert hasattr(gm, name)
function_module = getattr(gm, name)
_checkpoint_function_converter(function_module)
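The conversion above relies on _torch_to_thunder_function_map: every call_function node whose target is a Torch callable (outside of _operator and builtins) is replaced by its Thunder counterpart, which is a thunder.core.symbol.Symbol. Below is a small illustrative check, assuming a working Thunder installation; torch.sin is used only as an example entry of the map.

# Illustrative lookup, not part of the diff: the map used by _checkpoint_function_converter
# associates Torch callables with Thunder symbols.
import torch
from thunder.core.symbol import Symbol
from thunder.torch import _torch_to_thunder_function_map

thunder_sin = _torch_to_thunder_function_map[torch.sin]
assert isinstance(thunder_sin, Symbol)  # the same check done in test_checkpoint_converter_submodule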
86 changes: 86 additions & 0 deletions thunder/tests/test_dynamo.py
@@ -1,11 +1,15 @@
import pytest
import warnings
import torch
import torch.fx
import torch.nn as nn
import torch.nn.functional as F

from thunder import dtypes
from thunder.dynamo import ThunderCompiler
from thunder.dynamo.compiler_graph_benchmark import ThunderCompilerGraphBenchmarking
from thunder import last_traces
from thunder.core.symbol import Symbol
from thunder.tests.bf16 import device_supports_bf16
from thunder.tests.framework import (
instantiate,
@@ -535,3 +539,85 @@ def f(x):
)
compiled = torch.compile(backend=backend)(f)
compiled(x)


@requiresCUDA
@pytest.mark.filterwarnings(r"ignore:`torch\.cpu\.amp\.autocast\((.*?)\)` is deprecated.*:FutureWarning")
def test_checkpoint_converter():
import torch.utils.checkpoint as checkpoint

class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.layer1 = nn.Linear(10, 20)
self.layer2 = nn.Linear(20, 20)

def forward(self, x):
x = torch.sin(x)
x = checkpoint.checkpoint(self.layer1, x)
x = checkpoint.checkpoint(self.layer2, x)
x = F.relu(x)
return x

# Input tensor
x = torch.randn(5, 10).cuda().requires_grad_()
x_ref = x.detach().requires_grad_()

model = SimpleModel().cuda().train()
ref_model = SimpleModel().cuda().train()
ref_model.load_state_dict(model.state_dict())

backend = ThunderCompiler()
jf = torch.compile(backend=backend)(model)

ref_out = ref_model(x_ref)
out = jf(x)
torch.testing.assert_close(ref_out, out)

g = torch.randn_like(out)
out.backward(g)

ref_g = g.clone()
ref_out.backward(ref_g)
torch.testing.assert_close(x.grad, x_ref.grad)
torch.testing.assert_close(tuple(model.parameters()), tuple(ref_model.parameters()))


@requiresCUDA
def test_checkpoint_converter_submodule():
import torch.utils.checkpoint as checkpoint

class SubModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.lin = nn.Sequential(nn.ReLU(), nn.Linear(10, 10))

def forward(self, x):
return self.lin(x)

class SimpleModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.sub_mod = SubModule()

def forward(self, x):
x = torch.sin(x)
x = checkpoint.checkpoint(self.sub_mod, x)
return x

x = torch.randn(5, 10).cuda().requires_grad_()
model = SimpleModel().cuda().train()
backend = ThunderCompiler()
jf = torch.compile(backend=backend)(model)
out = jf(x)

subgraph_info = backend.subgraph_infos[0]
split_m = subgraph_info.split_graph_module
submodule_name = "wrap_body_0"
assert hasattr(split_m, submodule_name)

submodule = getattr(split_m, submodule_name)

for n in submodule.graph.nodes:
if n.op == "call_function":
assert isinstance(n.target, Symbol)
4 changes: 2 additions & 2 deletions thunder/torch/__init__.py
@@ -5262,8 +5262,8 @@ def _backward_checkpoint(
) -> tuple[None | TensorLike, ...]:
from thunder.core.transforms import vjp

result = vjp(function)(args, grad_outputs, **kwargs)
return result
result, grads = vjp(function)(args, grad_outputs, **kwargs)
return grads


#