diff --git a/thunder/dynamo/splitter.py b/thunder/dynamo/splitter.py
index c267f63008..7440684e7a 100644
--- a/thunder/dynamo/splitter.py
+++ b/thunder/dynamo/splitter.py
@@ -148,14 +148,14 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
     for node in split_gm.graph.nodes:
         if is_thunder_supported_partition(node):
             graph_module = getattr(split_gm, node.name)
-            # Replace PyTorch operators within the checkpointed function with the corresponding Thunder operators
-            checkpoint_converter(split_gm, graph_module)
             # Record the input tensor metadata of the current module based on the faketensor 'example_value' of the placeholder node
             placeholders = list(n for n in graph_module.graph.nodes if n.op == "placeholder")
             example_input_metadata = map(
                 partial(_get_example_inputs_from_placeholder, only_metadata=True), placeholders
             )
             example_input_metadatas.append(list(example_input_metadata))
+            # Replace PyTorch operators within the checkpointed function with the corresponding Thunder operators
+            checkpoint_converter(split_gm, graph_module)
             jit_fn = thunder_jit(graph_module)
             # Update the node name from "submod_*" to "thunder_*" for more user-friendly names
             update_node_and_submodule(split_gm, node, node.name.replace("submod", "thunder"), jit_fn)
diff --git a/thunder/dynamo/utils.py b/thunder/dynamo/utils.py
index b4cc29a5e8..493952e5f6 100644
--- a/thunder/dynamo/utils.py
+++ b/thunder/dynamo/utils.py
@@ -719,6 +719,8 @@ def reproducer(
        print(torch_env, file=f)
        print("\nVersions of Thunder related libraries:", file=f)
        print(thunder_pkgs, file=f)
+        print("\nThe torch.fx.Graph:", file=f)
+        print(gm.graph, file=f)
        print('"""', file=f)
        print("import os\n", file=f)
        print("import torch", file=f)
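
For context, the splitter change reorders the per-partition work so that placeholder metadata is read from the `example_value` FakeTensors before `checkpoint_converter` mutates the graph module, and the utils change additionally writes the `torch.fx.Graph` repr into the generated reproducer's header docstring. The snippet below is a minimal, self-contained sketch of that ordering concern only; the `Add` module and `capture_placeholder_metadata` helper are illustrative stand-ins and not part of Thunder.

import torch
import torch.fx


class Add(torch.nn.Module):  # toy module, for illustration only
    def forward(self, x, y):
        return x + y


def capture_placeholder_metadata(gm: torch.fx.GraphModule) -> list[tuple[str, object]]:
    # Under Dynamo, placeholder nodes carry a FakeTensor in n.meta["example_value"];
    # with plain symbolic_trace (as used here) the entry may simply be absent (None).
    return [
        (n.name, n.meta.get("example_value"))
        for n in gm.graph.nodes
        if n.op == "placeholder"
    ]


gm = torch.fx.symbolic_trace(Add())

# 1) Record input metadata from the untouched graph first ...
metadata = capture_placeholder_metadata(gm)

# 2) ... then run any graph-mutating conversion pass (checkpoint_converter in the
#    real splitter) and hand the module to the backend compiler.
print(metadata)
print(gm.graph)  # the same repr the reproducer change embeds in its header docstring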