diff --git a/graph_net/test/typical_sequence_decomposer_test.sh b/graph_net/test/typical_sequence_decomposer_test.sh
index ca327ba31..c28ec1334 100644
--- a/graph_net/test/typical_sequence_decomposer_test.sh
+++ b/graph_net/test/typical_sequence_decomposer_test.sh
@@ -14,7 +14,7 @@ python3 -m graph_net.model_path_handler \
     --model-path-list $model_list \
     --handler-config=$(base64 -w 0 < bool:
+        return self.naive_sample_handled(
+            rel_model_path, search_file_name="op_names.txt"
+        )
+
+    def resume(self, rel_model_path: str):
+        torch.cuda.empty_cache()
+        model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
+        op_names = self._extract_ops(model_path)
+        output_path = self._get_output_path(rel_model_path)
+        output_path.write_text("\n".join(op_names))
+        print(f"Save op-names to {str(output_path)}")
+
+    def __call__(self, rel_model_path: str):
+        self.resumable_handle_sample(rel_model_path)
+
+    def _make_config(
+        self, model_path_prefix: str, output_dir: str, resume: bool = False
+    ):
+        return {
+            "model_path_prefix": model_path_prefix,
+            "resume": resume,
+            "output_dir": output_dir,
+        }
+
+    def _get_output_path(self, rel_model_path: str):
+        output_path_dir = Path(self.config["output_dir"]) / rel_model_path
+        output_path_dir.mkdir(parents=True, exist_ok=True)
+        output_path = output_path_dir / "op_names.txt"
+        return output_path
+
+    def _extract_ops(self, model_path: str) -> List[str]:
+        extractor = TypicalSequenceExtractor()
+        model, inputs = get_torch_module_and_inputs(model_path)
+        compiled_model, _ = parse_sole_graph_module_without_varify(model, inputs)
+        extractor.extract_compiler(compiled_model, inputs)
+        ops_info = extractor.extract_node
+
+        return [op["target_name"] for op in ops_info]
+
 class TypicalSequenceExtractor:
     def __init__(self):
@@ -52,46 +105,3 @@ def extract_compiler(self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor])
         operator = self._extract_operators_from_graph(gm, inputs)
         self.extract_node = operator
         return gm.forward
-
-
-
-class OpNamesExtractor:
-    def __init__(self, config=None):
-        if config is None:
-            config = {}
-
-        self.config = self._make_config(**config)
-
-    def _make_config(
-        self, model_path_prefix: str, output_dir: str, resume: bool = False
-    ):
-        return {
-            "model_path_prefix": model_path_prefix,
-            "resume": resume,
-            "output_dir": output_dir,
-        }
-
-    def __call__(self, rel_model_path: str):
-        torch.cuda.empty_cache()
-        model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
-        output_path = self._get_output_path(rel_model_path)
-        if self.config["resume"] and output_path.exists():
-            return
-        op_names = self._extract_ops(model_path)
-        output_path.write_text("\n".join(op_names))
-        print(f"Save op-names to {str(output_path)}")
-
-    def _get_output_path(self, rel_model_path: str):
-        output_path_dir = Path(self.config["output_dir"]) / rel_model_path
-        output_path_dir.mkdir(parents=True, exist_ok=True)
-        output_path = output_path_dir / "op_names.txt"
-        return output_path
-
-    def _extract_ops(self, model_path: str) -> List[str]:
-        extractor = TypicalSequenceExtractor()
-        model, inputs = get_torch_module_and_inputs(model_path)
-        compiled_model, _ = parse_sole_graph_module_without_varify(model, inputs)
-        extractor.extract_compiler(compiled_model, inputs)
-        ops_info = extractor.extract_node
-
-        return [op["target_name"] for op in ops_info]
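
Note on the refactor: the added methods rely on a resumable-handler base or mixin that supplies `resumable_handle_sample` and `naive_sample_handled`, which is not shown in this hunk. The sketch below is only an illustration of how such a helper could behave, reusing the `config` keys visible above; the class name `ResumableSampleHandlerMixin` and the exact lookup logic are assumptions, not the repository's actual implementation.

```python
from pathlib import Path


class ResumableSampleHandlerMixin:
    """Sketch only: assumes subclasses provide self.config plus
    sample_handled() and resume(), as OpNamesExtractor does in this diff."""

    def resumable_handle_sample(self, rel_model_path: str):
        # In resume mode, skip samples whose output already exists.
        if self.config["resume"] and self.sample_handled(rel_model_path):
            return
        self.resume(rel_model_path)

    def naive_sample_handled(self, rel_model_path: str, search_file_name: str) -> bool:
        # A sample counts as handled once its output file has been written.
        output_file = (
            Path(self.config["output_dir"]) / rel_model_path / search_file_name
        )
        return output_file.exists()
```

Splitting "is this sample already done?" (`sample_handled`) from "do the work" (`resume`) moves the skip-on-resume policy out of `__call__`, which the removed `OpNamesExtractor` previously handled inline with `if self.config["resume"] and output_path.exists(): return`.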