diff --git a/thunder/dynamo/splitter.py b/thunder/dynamo/splitter.py index 4b455f60b6..587729dde3 100644 --- a/thunder/dynamo/splitter.py +++ b/thunder/dynamo/splitter.py @@ -138,6 +138,13 @@ def callback(node) -> int: original_split_gm: torch.fx.GraphModule = split_module( gm, root_m=None, split_callback=callback, keep_original_order=True, keep_original_node_name=True ) + + # Workaround for the Torch bug https://github.com/pytorch/pytorch/pull/139275 + for submodule in original_split_gm.children(): + if not submodule.graph.find_nodes(op="output"): + submodule.graph.output(()) + if not original_split_gm.graph.find_nodes(op="output"): + original_split_gm.graph.output(()) split_gm = copy.deepcopy(original_split_gm) def is_thunder_supported_partition(node: torch.fx.Node) -> bool: diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py index cc740ff408..dd03580991 100644 --- a/thunder/tests/test_dynamo.py +++ b/thunder/tests/test_dynamo.py @@ -449,10 +449,6 @@ def func(x): IS_WINDOWS, reason="torch.compile Windows support is still WIP - https://github.com/pytorch/pytorch/issues/122094", ), - pytest.mark.skipif( - LooseVersion(torch.__version__) < LooseVersion("2.6.0"), - reason="Skip until the Torch bug is fixed - https://github.com/pytorch/pytorch/pull/139275", - ), pytest.mark.skipif( version_between(torch.__version__, min_ver="2.6.0dev0", max_ver="2.6.0a99"), reason="https://github.com/Lightning-AI/lightning-thunder/issues/1471", @@ -864,3 +860,28 @@ def forward(self, x): cmd = "pytest" if use_pytest_benchmark else "python" result1 = run([cmd, s1], capture_output=True, text=True) assert result1.returncode == 0, f"Reproducer {s1} failed with return code {result1.returncode}" + + +def test_deepcopy_graph_module(): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + y = x + 1 + + m = MyModule() + gm = torch.fx.symbolic_trace(m) + n = gm.graph.find_nodes(op="output") + gm.graph.erase_node(n[0]) + import thunder + + _, subgraph_info = thunder.dynamo.splitter._splitter(gm, thunder.jit, thunder.jit, []) + original_split_gm = subgraph_info.original_split_graph_module + assert original_split_gm.graph.find_nodes(op="output") + for subm in original_split_gm.children(): + assert subm.graph.find_nodes(op="output") + import copy + + # No assertion error + copy_gm = copy.deepcopy(original_split_gm)