Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add output node if it does not exist in the split module #1480

Merged
7 commits merged on Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions thunder/dynamo/splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,13 @@ def callback(node) -> int:
original_split_gm: torch.fx.GraphModule = split_module(
gm, root_m=None, split_callback=callback, keep_original_order=True, keep_original_node_name=True
)

# Workaround for the Torch bug https://github.com/pytorch/pytorch/pull/139275
for submodule in original_split_gm.children():
if not submodule.graph.find_nodes(op="output"):
submodule.graph.output(())
if not original_split_gm.graph.find_nodes(op="output"):
original_split_gm.graph.output(())
split_gm = copy.deepcopy(original_split_gm)

def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
Expand Down
29 changes: 25 additions & 4 deletions thunder/tests/test_dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,10 +449,6 @@ def func(x):
IS_WINDOWS,
reason="torch.compile Windows support is still WIP - https://github.com/pytorch/pytorch/issues/122094",
),
pytest.mark.skipif(
LooseVersion(torch.__version__) < LooseVersion("2.6.0"),
reason="Skip until the Torch bug is fixed - https://github.com/pytorch/pytorch/pull/139275",
),
pytest.mark.skipif(
version_between(torch.__version__, min_ver="2.6.0dev0", max_ver="2.6.0a99"),
reason="https://github.com/Lightning-AI/lightning-thunder/issues/1471",
Expand Down Expand Up @@ -864,3 +860,28 @@ def forward(self, x):
cmd = "pytest" if use_pytest_benchmark else "python"
result1 = run([cmd, s1], capture_output=True, text=True)
assert result1.returncode == 0, f"Reproducer {s1} failed with return code {result1.returncode}"


def test_deepcopy_graph_module():
    """Regression test: a split graph module whose output node was erased
    must still be deepcopy-able (workaround for the upstream Torch bug,
    https://github.com/pytorch/pytorch/pull/139275)."""

    class _NoReturnModule(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            # Creates one FX node; forward deliberately returns nothing.
            y = x + 1

    traced = torch.fx.symbolic_trace(_NoReturnModule())
    # Erase the output node to reproduce the condition the splitter must handle.
    output_nodes = traced.graph.find_nodes(op="output")
    traced.graph.erase_node(output_nodes[0])
    import thunder

    _, subgraph_info = thunder.dynamo.splitter._splitter(traced, thunder.jit, thunder.jit, [])
    split_gm = subgraph_info.original_split_graph_module
    # The splitter is expected to have re-inserted output nodes everywhere.
    assert split_gm.graph.find_nodes(op="output")
    assert all(child.graph.find_nodes(op="output") for child in split_gm.children())
    import copy

    # Must not raise now that every (sub)graph has an output node.
    copy.deepcopy(split_gm)
Loading