From 7951a51c2676128a11d6de188a244e39c2f66583 Mon Sep 17 00:00:00 2001 From: Yan Wang Date: Tue, 26 Nov 2024 21:27:39 +0100 Subject: [PATCH 1/4] Add output node if it does not exist in the split module (#1476) --- thunder/dynamo/splitter.py | 13 +++++++++++++ thunder/tests/test_dynamo.py | 4 ---- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/thunder/dynamo/splitter.py b/thunder/dynamo/splitter.py index b128357b97..f52dff0f7b 100644 --- a/thunder/dynamo/splitter.py +++ b/thunder/dynamo/splitter.py @@ -135,6 +135,19 @@ def callback(node) -> int: original_split_gm: torch.fx.GraphModule = split_module( gm, root_m=None, split_callback=callback, keep_original_order=True, keep_original_node_name=True ) + + def add_output(m): + has_output = False + for node in m.graph.nodes: + if node.op == "call_module": + add_output(getattr(m, node.target)) + elif node.op == "output": + has_output = True + if not has_output: + m.graph.output(()) + + # Workaround for the Torch bug https://github.com/pytorch/pytorch/pull/139275 + add_output(original_split_gm) split_gm = copy.deepcopy(original_split_gm) def is_thunder_supported_partition(node: torch.fx.Node) -> bool: diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py index 65e6603f54..a580362c91 100644 --- a/thunder/tests/test_dynamo.py +++ b/thunder/tests/test_dynamo.py @@ -447,10 +447,6 @@ def func(x): IS_WINDOWS, reason="torch.compile Windows support is still WIP - https://github.com/pytorch/pytorch/issues/122094", ), - pytest.mark.skipif( - LooseVersion(torch.__version__) < LooseVersion("2.6.0"), - reason="Skip until the Torch bug is fixed - https://github.com/pytorch/pytorch/pull/139275", - ), pytest.mark.skipif( version_between(torch.__version__, min_ver="2.6.0dev0", max_ver="2.6.0a99"), reason="https://github.com/Lightning-AI/lightning-thunder/issues/1471", From 39a4fd1aa9908f50e51ae768f2ca035f4622cac2 Mon Sep 17 00:00:00 2001 From: Yan Wang Date: Wed, 27 Nov 2024 09:48:59 +0100 Subject: [PATCH 2/4] fix --- thunder/dynamo/splitter.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/thunder/dynamo/splitter.py b/thunder/dynamo/splitter.py index f52dff0f7b..435098028b 100644 --- a/thunder/dynamo/splitter.py +++ b/thunder/dynamo/splitter.py @@ -136,18 +136,11 @@ def callback(node) -> int: gm, root_m=None, split_callback=callback, keep_original_order=True, keep_original_node_name=True ) - def add_output(m): - has_output = False - for node in m.graph.nodes: - if node.op == "call_module": - add_output(getattr(m, node.target)) - elif node.op == "output": - has_output = True - if not has_output: - m.graph.output(()) - # Workaround for the Torch bug https://github.com/pytorch/pytorch/pull/139275 - add_output(original_split_gm) + for submodule in original_split_gm.children(): + last_node = next(iter(reversed(submodule.graph.nodes))) + if last_node.op != "output": + submodule.graph.output(()) split_gm = copy.deepcopy(original_split_gm) def is_thunder_supported_partition(node: torch.fx.Node) -> bool: From 909a1a93afb45898c3637840216c0555a5fb81a9 Mon Sep 17 00:00:00 2001 From: Yan Wang Date: Wed, 27 Nov 2024 18:35:45 +0100 Subject: [PATCH 3/4] add test --- thunder/dynamo/splitter.py | 5 +++-- thunder/tests/test_dynamo.py | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/thunder/dynamo/splitter.py b/thunder/dynamo/splitter.py index 65aa0bd70b..587729dde3 100644 --- a/thunder/dynamo/splitter.py +++ b/thunder/dynamo/splitter.py @@ -141,9 +141,10 @@ def callback(node) -> int: # Workaround for the Torch bug https://github.com/pytorch/pytorch/pull/139275 for submodule in original_split_gm.children(): - last_node = next(iter(reversed(submodule.graph.nodes))) - if last_node.op != "output": + if not submodule.graph.find_nodes(op="output"): submodule.graph.output(()) + if not original_split_gm.graph.find_nodes(op="output"): + original_split_gm.graph.output(()) split_gm = copy.deepcopy(original_split_gm) def is_thunder_supported_partition(node: torch.fx.Node) -> bool: diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py index a28187502c..8522ef253f 100644 --- a/thunder/tests/test_dynamo.py +++ b/thunder/tests/test_dynamo.py @@ -860,3 +860,22 @@ def forward(self, x): cmd = "pytest" if use_pytest_benchmark else "python" result1 = run([cmd, s1], capture_output=True, text=True) assert result1.returncode == 0, f"Reproducer {s1} failed with return code {result1.returncode}" + + +def test_deepcopy_graph_module(): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + y = x + 1 + + m = MyModule() + gm = torch.fx.symbolic_trace(m) + n = gm.graph.find_nodes(op="output") + gm.graph.erase_node(n[0]) + import thunder + + _, subgraph_info = thunder.dynamo.splitter._splitter(gm, thunder.jit, thunder.jit, []) + original_split_gm = subgraph_info.original_split_graph_module + assert original_split_gm.graph.find_nodes(op="output") From 4bd3375ecb2d253fb7deb0a68f5ed3f5d907f5e9 Mon Sep 17 00:00:00 2001 From: Yan Wang Date: Wed, 27 Nov 2024 18:59:42 +0100 Subject: [PATCH 4/4] fix --- thunder/tests/test_dynamo.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py index 8522ef253f..dd03580991 100644 --- a/thunder/tests/test_dynamo.py +++ b/thunder/tests/test_dynamo.py @@ -879,3 +879,9 @@ def forward(self, x): _, subgraph_info = thunder.dynamo.splitter._splitter(gm, thunder.jit, thunder.jit, []) original_split_gm = subgraph_info.original_split_graph_module assert original_split_gm.graph.find_nodes(op="output") + for subm in original_split_gm.children(): + assert subm.graph.find_nodes(op="output") + import copy + + # No assertion error + copy_gm = copy.deepcopy(original_split_gm)