From 40bafbcbf9157d706448121f6b738c1234352c12 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Wed, 24 Dec 2025 11:47:40 +0800 Subject: [PATCH 1/2] Implement BackwardGraphExtractor for torch. --- .../test/backward_graph_extractor_test.sh | 29 ++++ .../sample_passes/backward_graph_extractor.py | 148 ++++++++++++++++++ 2 files changed, 177 insertions(+) create mode 100644 graph_net/test/backward_graph_extractor_test.sh create mode 100644 graph_net/torch/sample_passes/backward_graph_extractor.py diff --git a/graph_net/test/backward_graph_extractor_test.sh b/graph_net/test/backward_graph_extractor_test.sh new file mode 100644 index 000000000..3cc260e0d --- /dev/null +++ b/graph_net/test/backward_graph_extractor_test.sh @@ -0,0 +1,29 @@ +#!/usr/bin/env bash + +GRAPH_NET_ROOT=$(python -c "import graph_net, os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))") +MODEL_PATH_PREFIX=$GRAPH_NET_ROOT +OUTPUT_DIR=/tmp/backward_graph_workspace +FRAMEWORK="torch" +HANDLER_CONFIG=$(base64 -w 0 < bool: + return self.naive_sample_handled(rel_model_path, search_file_name="model.py") + + def resume(self, rel_model_path: str): + model_path_prefix = Path(self.config["model_path_prefix"]) + model_name = f"{os.path.basename(rel_model_path)}_backward" + model_path = model_path_prefix / rel_model_path + output_dir = Path(self.config["output_dir"]) / os.path.dirname(rel_model_path) + device = self._choose_device(self.config["device"]) + extractor = BackwardGraphExtractor(model_name, model_path, output_dir, device) + extractor() + + def _choose_device(self, device) -> str: + if device in ["cpu", "cuda"]: + return device + return "cuda" if torch.cuda.is_available() else "cpu" From de2a081300d3f285d39a08f693cf46e73ac50b4e Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Wed, 24 Dec 2025 13:26:47 +0800 Subject: [PATCH 2/2] Remove None from returns. --- .../sample_passes/backward_graph_extractor.py | 25 +++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/graph_net/torch/sample_passes/backward_graph_extractor.py b/graph_net/torch/sample_passes/backward_graph_extractor.py index 384403a03..1e5c27e99 100644 --- a/graph_net/torch/sample_passes/backward_graph_extractor.py +++ b/graph_net/torch/sample_passes/backward_graph_extractor.py @@ -37,7 +37,7 @@ def __call__(self): self.model_path, module, example_inputs ) bw_gm, backward_inputs = self.capture_backward_graph(module, example_inputs) - print(bw_gm.graph) + # print(bw_gm.graph) self.builtin_extractor(bw_gm, backward_inputs) def capture_backward_graph(self, module, example_inputs): @@ -76,7 +76,28 @@ def wrapped_forward(*args): outs_grad = [torch.ones_like(out) for out in outs] torch.autograd.backward(outs, outs_grad) - return backward_gm_holder["gm"], backward_inputs + bw_gm = self._remove_none_from_output(backward_gm_holder["gm"]) + return bw_gm, backward_inputs + + def _remove_none_from_output(self, gm): + output_node = next( + (n for n in gm.graph.nodes if n.op == "output"), + None, + ) + outs = ( + output_node.args[0] + if output_node and isinstance(output_node.args, (tuple, list)) + else output_node.args + ) + if isinstance(outs, (tuple, list)): + new_outs = tuple(out for out in outs if out is not None) + if new_outs != outs: + output_node.args = (new_outs,) + + gm.graph.eliminate_dead_code() + gm.graph.lint() + gm.recompile() + return gm def _requires_grad(self, name, tensor): if not tensor.is_floating_point():