Skip to content

Commit

Permalink
Add an output node if it does not exist in the split module (#1480)
Browse files Browse the repository at this point in the history
Fixes #1476.

Add a workaround to make ThunderFX work with an older version of PyTorch by going through all submodules of split_module and adding an output node if it's missing.
  • Loading branch information
kiya00 authored Dec 2, 2024
1 parent e0d1494 commit 15c48ef
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 4 deletions.
7 changes: 7 additions & 0 deletions thunder/dynamo/splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,13 @@ def callback(node) -> int:
original_split_gm: torch.fx.GraphModule = split_module(
gm, root_m=None, split_callback=callback, keep_original_order=True, keep_original_node_name=True
)

# Workaround for the Torch bug https://github.com/pytorch/pytorch/pull/139275
for submodule in original_split_gm.children():
if not submodule.graph.find_nodes(op="output"):
submodule.graph.output(())
if not original_split_gm.graph.find_nodes(op="output"):
original_split_gm.graph.output(())
split_gm = copy.deepcopy(original_split_gm)

def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
Expand Down
29 changes: 25 additions & 4 deletions thunder/tests/test_dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,10 +449,6 @@ def func(x):
IS_WINDOWS,
reason="torch.compile Windows support is still WIP - https://github.com/pytorch/pytorch/issues/122094",
),
pytest.mark.skipif(
LooseVersion(torch.__version__) < LooseVersion("2.6.0"),
reason="Skip until the Torch bug is fixed - https://github.com/pytorch/pytorch/pull/139275",
),
pytest.mark.skipif(
version_between(torch.__version__, min_ver="2.6.0dev0", max_ver="2.6.0a99"),
reason="https://github.com/Lightning-AI/lightning-thunder/issues/1471",
Expand Down Expand Up @@ -864,3 +860,28 @@ def forward(self, x):
cmd = "pytest" if use_pytest_benchmark else "python"
result1 = run([cmd, s1], capture_output=True, text=True)
assert result1.returncode == 0, f"Reproducer {s1} failed with return code {result1.returncode}"


def test_deepcopy_graph_module():
    """Regression test for https://github.com/pytorch/pytorch/pull/139275.

    ``copy.deepcopy`` of a ``torch.fx.GraphModule`` whose graph has no
    output node used to fail on older PyTorch versions.  The splitter works
    around this by appending an empty output node to the split module and to
    every submodule graph that is missing one.  This test erases the output
    node before splitting and verifies that (a) every resulting graph has an
    output node again and (b) deepcopy of the split module succeeds.
    """
    import copy

    import thunder

    class MyModule(torch.nn.Module):
        def forward(self, x):
            # No return on purpose: the statement below only ensures the
            # traced graph contains at least one call node.
            y = x + 1

    m = MyModule()
    gm = torch.fx.symbolic_trace(m)
    # Erase the output node to reproduce the problematic graph shape.
    output_nodes = gm.graph.find_nodes(op="output")
    gm.graph.erase_node(output_nodes[0])

    _, subgraph_info = thunder.dynamo.splitter._splitter(gm, thunder.jit, thunder.jit, [])
    original_split_gm = subgraph_info.original_split_graph_module
    # The workaround must have restored an output node on the root graph
    # and on every submodule graph.
    assert original_split_gm.graph.find_nodes(op="output")
    for subm in original_split_gm.children():
        assert subm.graph.find_nodes(op="output")

    # Must not raise: deepcopy failed here before the workaround.
    copy.deepcopy(original_split_gm)

0 comments on commit 15c48ef

Please sign in to comment.