"""Demo script: run the GraphNet extraction agent on one small Hugging Face model."""

import logging
import os
import sys

# Ensure we can import graph_net when the script is launched from the repo root.
sys.path.append(os.getcwd())


def main() -> int:
    """Process a small test model with ``GraphNetAgent`` and report the outcome.

    Returns:
        int: ``0`` when the agent reports success, ``1`` otherwise, so the
        value can be used directly as the process exit status.
    """
    # Imported here (not at module top) so the sys.path adjustment above has
    # already taken effect when graph_net is resolved.
    from graph_net.agent import GraphNetAgent

    # The agent reports its progress through ``logging`` at INFO level; without
    # a configured handler those messages would never be displayed.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(name)s %(levelname)s: %(message)s",
    )

    # Use a writable workspace located next to this script.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    workspace = os.path.join(script_dir, "agent_workspace")
    print(f"Using workspace: {workspace}")

    agent = GraphNetAgent(workspace=workspace)

    # Use a small model for testing
    test_model = "prajjwal1/bert-tiny"

    print(f"Processing model: {test_model}")
    success = agent.process_model(test_model)

    if not success:
        print("\n[FAILURE] Agent failed to process the model.")
        return 1

    # Build the location with os.path so separators stay consistent with
    # ``workspace`` on every platform.
    result_dir = os.path.join(
        workspace,
        "downloads",
        test_model.replace("/", "_"),
        "extracted_sample",
    )
    print("\n[SUCCESS] Agent successfully processed the model!")
    print(f"Check results in: {result_dir}")
    return 0


if __name__ == "__main__":
    sys.exit(main())
"""GraphNet automatic sample-extraction agent.

English summary of the accompanying design document (docs/agent_design.md):

* Goal: enrich the GraphNet sample library by downloading models from
  Hugging Face and converting them into GraphNet samples without
  hand-writing a runner script for every model.
* Pipeline: HF model id -> fetcher (snapshot download) -> analyzer
  (reads ``config.json``) -> code generator (template based; an LLM based
  generator is described as a reserved, optional interface) -> extractor
  (runs the generated script in a subprocess) -> verifier (checks that
  ``graph_net.json``, ``model.py`` and ``input_meta.py`` exist and are valid).
* Planned extensions listed in the document: more HF task types (NLP, CV,
  audio), DeepSeek/OpenAI backed code generation, automated PR submission.

NOTE: the interface sketch in the design document names the entry point
``run``; the implemented entry point is ``GraphNetAgent.process_model``.
"""

import json
import logging
import os
import subprocess
import sys
import textwrap
from typing import Any, Dict, Optional

__all__ = [
    "GraphNetAgent",
    "ConfigAnalyzer",
    "TemplateCoder",
    "SubprocessExtractor",
    "HFFetcher",
    "BasicVerifier",
]


class ConfigAnalyzer:
    """Infer model input specifications from a model directory's ``config.json``."""

    def __init__(self):
        self.logger = logging.getLogger("ConfigAnalyzer")

    def analyze(self, model_dir: str) -> Dict[str, Any]:
        """Analyze ``config.json`` to infer input specifications.

        Args:
            model_dir: Directory that contains ``config.json``.

        Returns:
            A dict with ``architecture``, ``input_shape``, ``input_dtype`` and
            ``task_type``; ``input_names`` is added when the architecture
            matches one of the known families.

        Raises:
            FileNotFoundError: If ``config.json`` is missing.
            ValueError: If ``config.json`` does not contain a JSON object.
        """
        config_path = os.path.join(model_dir, "config.json")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"config.json not found in {model_dir}")

        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        if not isinstance(config, dict):
            raise ValueError(f"config.json in {model_dir} is not a JSON object")

        # "architectures" may be absent, null or an empty list; none of those
        # may crash the analysis, they all mean "Unknown".
        architectures = config.get("architectures")
        if isinstance(architectures, str):
            architecture = architectures
        elif architectures:
            architecture = str(architectures[0])
        else:
            architecture = "Unknown"
        self.logger.info(f"Detected architecture: {architecture}")

        # Heuristic defaults: batch size 1, sequence length 128, NLP model.
        meta_info: Dict[str, Any] = {
            "architecture": architecture,
            "input_shape": [1, 128],
            "input_dtype": "int64",
            "task_type": "nlp",
        }

        # Refine based on architecture. "GPT" is checked next to "Gpt" because
        # class names such as "GPT2LMHeadModel" use the upper-case spelling.
        if "Bert" in architecture or "Roberta" in architecture:
            meta_info["input_names"] = ["input_ids", "attention_mask", "token_type_ids"]
        elif any(tag in architecture for tag in ("Gpt", "GPT", "Llama")):
            meta_info["input_names"] = ["input_ids", "attention_mask"]
        elif "ResNet" in architecture or "ViT" in architecture:
            meta_info["task_type"] = "cv"
            meta_info["input_shape"] = [1, 3, 224, 224]
            meta_info["input_dtype"] = "float32"
            meta_info["input_names"] = ["pixel_values"]

        return meta_info


# Skeleton of the generated extraction script. Placeholders are substituted
# with str.replace (not str.format) so the braces of the script's own
# f-strings need no escaping.
_SCRIPT_TEMPLATE = """\
import os
import sys

import torch
from transformers import AutoModel

# Ensure graph_net is in path
sys.path.append(os.getcwd())


def main():
    model_path = __MODEL_PATH_LITERAL__
    output_dir = os.path.join(model_path, "extracted_sample")

    print(f"Loading model from {model_path}...")
    try:
        # NOTE: trust_remote_code=True executes Python code shipped with the model.
        model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
        model.eval()
    except Exception as e:
        print(f"Failed to load model: {e}")
        sys.exit(1)

    print("Generating inputs...")
__INPUT_BLOCK__

    # Move to CUDA if available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    inputs = tuple(t.to(device) for t in inputs)

    print("Starting extraction...")
    # Setup environment variable for GraphNet workspace
    os.environ["GRAPH_NET_EXTRACT_WORKSPACE"] = output_dir

    # extract(name, dynamic=True)(model) returns a compiled model; it has to
    # run once to trigger compilation and extraction.
    from graph_net.torch.extractor import extract

    compiled_model = extract(name="subgraph", dynamic=True)(model)

    print("Running forward pass to trigger extraction...")
    with torch.no_grad():
        compiled_model(*inputs)

    print(f"Extraction complete. Results in {output_dir}")


if __name__ == "__main__":
    main()
"""


class TemplateCoder:
    """Generate a stand-alone extraction script from a fixed template."""

    def __init__(self):
        self.logger = logging.getLogger("TemplateCoder")

    def generate(self, model_dir: str, meta_info: Dict[str, Any]) -> str:
        """Write ``run_extraction.py`` into ``model_dir`` and return its path.

        Args:
            model_dir: Directory of the downloaded model; also the directory
                the script is written to.
            meta_info: Analysis result (``task_type``, ``input_shape``, ...).

        Returns:
            Path of the generated script.

        Raises:
            ValueError: If ``meta_info`` names an unsupported ``task_type``.
        """
        script_content = self._create_script_content(model_dir, meta_info)

        output_path = os.path.join(model_dir, "run_extraction.py")
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(script_content)

        return output_path

    def _create_script_content(self, model_dir: str, meta_info: Dict[str, Any]) -> str:
        """Return the source text of the extraction script.

        The script loads the model with ``transformers.AutoModel``, builds
        random inputs of the analyzed shape, wraps the model with
        ``graph_net.torch.extractor.extract`` and runs one forward pass.
        """
        task_type = meta_info.get("task_type", "nlp")
        shape = tuple(meta_info.get("input_shape", [1, 128]))

        if task_type == "nlp":
            input_lines = [
                "# NLP Inputs",
                f"input_ids = torch.randint(0, 100, {shape}, dtype=torch.int64)",
                f"attention_mask = torch.ones({shape}, dtype=torch.int64)",
                "inputs = (input_ids, attention_mask)",
            ]
        elif task_type == "cv":
            input_lines = [
                "# CV Inputs",
                f"inputs = (torch.randn({shape}, dtype=torch.float32),)",
            ]
        else:
            # Fail at generation time instead of emitting a script in which
            # ``inputs`` is never defined.
            raise ValueError(f"Unsupported task_type: {task_type!r}")

        # Indent explicitly so the snippet always lands inside ``main()``,
        # independent of how this source file itself is laid out.
        input_block = textwrap.indent("\n".join(input_lines), "    ")

        # repr() yields a valid string literal for any path; a raw-string
        # literal breaks on paths containing quotes or a trailing backslash.
        model_path_literal = repr(os.fspath(model_dir))

        script = _SCRIPT_TEMPLATE.replace("__INPUT_BLOCK__", input_block)
        return script.replace("__MODEL_PATH_LITERAL__", model_path_literal)


class SubprocessExtractor:
    """Run a generated extraction script in a child Python process."""

    def __init__(self, workspace: str, repo_root: Optional[str] = None):
        """
        Args:
            workspace: Agent workspace directory (stored, not otherwise used here).
            repo_root: Directory that contains the ``graph_net`` package. When
                omitted, it is guessed from this file's location.
        """
        self.workspace = workspace
        self.repo_root = repo_root
        self.logger = logging.getLogger("SubprocessExtractor")

    @staticmethod
    def _join_pythonpath(existing: Optional[str], root_dir: str) -> str:
        """Append ``root_dir`` to a PYTHONPATH value without empty or duplicate entries."""
        entries = [p for p in (existing or "").split(os.pathsep) if p]
        if root_dir not in entries:
            entries.append(root_dir)
        return os.pathsep.join(entries)

    def _resolve_root_dir(self) -> str:
        """Return the directory used as cwd and PYTHONPATH entry for the child."""
        if self.repo_root:
            return os.path.abspath(self.repo_root)
        # Heuristic: assumes this module sits two package levels below the
        # repository root (e.g. graph_net/agent/<module>.py) — pass
        # ``repo_root`` explicitly if the layout differs.
        current_file = os.path.abspath(__file__)
        return os.path.dirname(os.path.dirname(os.path.dirname(current_file)))

    def extract(self, script_path: str, model_id: str, timeout: Optional[float] = None) -> str:
        """Run the extraction script in a subprocess.

        Args:
            script_path: Path of the generated script.
            model_id: Model identifier, used in error messages only.
            timeout: Optional limit in seconds for the child process.

        Returns:
            The ``extracted_sample`` directory next to the script.

        Raises:
            RuntimeError: If the script exits with a non-zero status or times out.
            FileNotFoundError: If the expected output directory does not exist
                after the script finished.
        """
        # The script template writes to {model_dir}/extracted_sample.
        model_dir = os.path.dirname(script_path)
        expected_output_dir = os.path.join(model_dir, "extracted_sample")

        cmd = [sys.executable, script_path]
        self.logger.info(f"Executing: {' '.join(cmd)}")

        root_dir = self._resolve_root_dir()
        env = os.environ.copy()
        env["PYTHONPATH"] = self._join_pythonpath(env.get("PYTHONPATH"), root_dir)

        try:
            result = subprocess.run(
                cmd,
                cwd=root_dir,  # Run from root so ``graph_net`` is importable
                env=env,
                capture_output=True,
                text=True,
                check=True,
                timeout=timeout,
            )
        except subprocess.CalledProcessError as e:
            self.logger.error("Extraction script failed.")
            self.logger.error(f"STDOUT: {e.stdout}")
            self.logger.error(f"STDERR: {e.stderr}")
            raise RuntimeError(f"Extraction failed for {model_id}") from e
        except subprocess.TimeoutExpired as e:
            self.logger.error("Extraction script timed out.")
            raise RuntimeError(f"Extraction timed out for {model_id}") from e

        self.logger.info("Extraction script finished successfully.")
        self.logger.debug(result.stdout)

        if not os.path.exists(expected_output_dir):
            raise FileNotFoundError(f"Expected output directory {expected_output_dir} was not created.")

        return expected_output_dir


class HFFetcher:
    """Download model snapshots from Hugging Face into the workspace."""

    def __init__(self, workspace: str, token: Optional[str] = None):
        self.workspace = workspace
        self.token = token
        self.logger = logging.getLogger("HFFetcher")

    def download(self, model_id: str) -> str:
        """Download a model snapshot from Hugging Face.

        Args:
            model_id: Hugging Face repository id.

        Returns:
            Local directory ``<workspace>/downloads/<model_id with '/' -> '_'>``.

        Raises:
            ImportError: If ``huggingface_hub`` is not installed.
        """
        try:
            from huggingface_hub import snapshot_download
        except ImportError as err:
            raise ImportError(
                "huggingface_hub is not installed. Please install it with `pip install huggingface_hub`."
            ) from err

        self.logger.info(f"Downloading {model_id} from Hugging Face...")

        # Define local dir based on model_id
        local_dir = os.path.join(self.workspace, "downloads", model_id.replace("/", "_"))

        # Download only necessary files to save time/bandwidth
        allow_patterns = ["*.json", "*.bin", "*.safetensors", "*.py", "*.txt", "*.md"]

        snapshot_download(
            repo_id=model_id,
            local_dir=local_dir,
            token=self.token,
            allow_patterns=allow_patterns,
            # Copy files for easier manipulation.
            # NOTE(review): this flag is reportedly deprecated in newer
            # huggingface_hub releases — confirm against the pinned version.
            local_dir_use_symlinks=False,
        )

        return local_dir


class BasicVerifier:
    """Check that an extracted sample directory contains the expected files."""

    REQUIRED_FILES = ("graph_net.json", "model.py", "input_meta.py")

    def __init__(self):
        self.logger = logging.getLogger("BasicVerifier")

    def _locate_sample_dir(self, sample_dir: str) -> str:
        """Return the directory that should hold the sample files.

        Preference order: ``<sample_dir>/subgraph``; ``sample_dir`` itself when
        it already holds ``graph_net.json``; otherwise the first sub-directory
        (sorted by name) whose name starts with ``subgraph``; finally
        ``sample_dir`` as a fallback.
        """
        preferred = os.path.join(sample_dir, "subgraph")
        if os.path.exists(preferred):
            return preferred
        if os.path.exists(os.path.join(sample_dir, "graph_net.json")):
            return sample_dir
        try:
            children = sorted(os.listdir(sample_dir))
        except OSError:
            return sample_dir
        for name in children:
            candidate = os.path.join(sample_dir, name)
            if name.startswith("subgraph") and os.path.isdir(candidate):
                return candidate
        return sample_dir

    def verify(self, sample_dir: str) -> bool:
        """Verify the integrity of the extracted sample.

        Args:
            sample_dir: Output directory of the extraction step.

        Returns:
            ``True`` when all required files exist and ``graph_net.json``
            parses to a JSON object, ``False`` otherwise.
        """
        target_dir = self._locate_sample_dir(sample_dir)

        missing = [
            name
            for name in self.REQUIRED_FILES
            if not os.path.exists(os.path.join(target_dir, name))
        ]
        if missing:
            self.logger.error(f"Verification Failed: Missing files {missing} in {target_dir}")
            return False

        try:
            with open(os.path.join(target_dir, "graph_net.json"), "r", encoding="utf-8") as f:
                config = json.load(f)
        except (OSError, ValueError) as e:
            # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
            self.logger.error(f"Verification Failed: Invalid json - {e}")
            return False

        if not isinstance(config, dict):
            self.logger.error("Verification Failed: Invalid json - expected a JSON object")
            return False

        if "hash" not in config and "model_hash" not in config:
            # A missing hash is reported but does not fail verification.
            self.logger.warning(f"graph_net.json in {target_dir} has no 'hash' or 'model_hash' entry")

        self.logger.info("Verification Passed.")
        return True


class GraphNetAgent:
    """Orchestrate download, analysis, code generation, extraction and verification."""

    def __init__(self, workspace: str, hf_token: Optional[str] = None):
        """
        Initialize the GraphNet Agent.

        Args:
            workspace (str): Directory where models and samples will be stored.
            hf_token (str, optional): Hugging Face API token.
        """
        self.workspace = os.path.abspath(workspace)
        os.makedirs(self.workspace, exist_ok=True)

        self.logger = logging.getLogger("GraphNetAgent")
        self.logger.setLevel(logging.INFO)

        # Initialize components
        self.fetcher = HFFetcher(self.workspace, token=hf_token)
        self.analyzer = ConfigAnalyzer()
        self.coder = TemplateCoder()
        self.extractor = SubprocessExtractor(self.workspace)
        self.verifier = BasicVerifier()

    def process_model(self, model_id: str) -> bool:
        """
        Process a single model: Download -> Analyze -> Generate Code -> Extract -> Verify.

        Any exception raised by a step is logged with its traceback and turned
        into a ``False`` result.

        Args:
            model_id (str): Hugging Face model ID (e.g. 'bert-base-uncased')

        Returns:
            bool: True if successful, False otherwise.
        """
        self.logger.info(f"Starting process for model: {model_id}")

        try:
            # 1. Download Model
            self.logger.info("Step 1: Downloading model...")
            model_dir = self.fetcher.download(model_id)
            self.logger.info(f"Model downloaded to: {model_dir}")

            # 2. Analyze Model Config
            self.logger.info("Step 2: Analyzing model config...")
            meta_info = self.analyzer.analyze(model_dir)
            self.logger.info(f"Analysis result: {meta_info}")

            # 3. Generate Running Script
            self.logger.info("Step 3: Generating extraction script...")
            script_path = self.coder.generate(model_dir, meta_info)
            self.logger.info(f"Script generated at: {script_path}")

            # 4. Extract Subgraph
            self.logger.info("Step 4: Extracting subgraph...")
            output_dir = self.extractor.extract(script_path, model_id)
            self.logger.info(f"Extraction output dir: {output_dir}")

            # 5. Verify Result
            self.logger.info("Step 5: Verifying result...")
            is_valid = self.verifier.verify(output_dir)
        except Exception:
            # Top-level boundary of the pipeline: log with traceback, report failure.
            self.logger.exception(f"Error processing {model_id}")
            return False

        if is_valid:
            self.logger.info(f"SUCCESS: Model {model_id} processed successfully.")
            return True

        self.logger.error(f"FAILURE: Verification failed for {model_id}.")
        return False