diff --git a/graph_net/test/backward_graph_extractor_test.sh b/graph_net/test/backward_graph_extractor_test.sh new file mode 100644 index 000000000..f71428659 --- /dev/null +++ b/graph_net/test/backward_graph_extractor_test.sh @@ -0,0 +1,29 @@ +#!/usr/bin/env bash + +GRAPH_NET_ROOT=$(python -c "import graph_net, os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))") +MODEL_PATH_PREFIX=$GRAPH_NET_ROOT +OUTPUT_DIR=/tmp/backward_graph_workspace +FRAMEWORK="torch" +HANDLER_CONFIG=$(base64 -w 0 < bool: + return self.naive_sample_handled(rel_model_path, search_file_name="model.py") + + def resume(self, rel_model_path: str): + model_path_prefix = Path(self.config["model_path_prefix"]) + model_name = f"{os.path.basename(rel_model_path)}_backward" + model_path = model_path_prefix / rel_model_path + output_dir = Path(self.config["output_dir"]) / os.path.dirname(rel_model_path) + device = self._choose_device(self.config["device"]) + extractor = BackwardGraphExtractor(model_name, model_path, output_dir, device) + extractor() + + def _choose_device(self, device) -> str: + if device in ["cpu", "cuda"]: + return device + return "cuda" if torch.cuda.is_available() else "cpu"