30 changes: 30 additions & 0 deletions demo_agent.py
@@ -0,0 +1,30 @@
import os
import sys

# Ensure we can import graph_net
sys.path.append(os.getcwd())

from graph_net.agent import GraphNetAgent

def main():
    # Set up a local workspace next to this script: a directory we know is
    # writable (the current working directory may not be, e.g. System32).
    workspace = os.path.join(os.path.dirname(os.path.abspath(__file__)), "agent_workspace")
    print(f"Using workspace: {workspace}")

    agent = GraphNetAgent(workspace=workspace)

    # Use a small model for testing
    test_model = "prajjwal1/bert-tiny"

    print(f"Processing model: {test_model}")
    success = agent.process_model(test_model)

    if success:
        print("\n[SUCCESS] Agent successfully processed the model!")
        print(f"Check results in: {workspace}/downloads/{test_model.replace('/', '_')}/extracted_sample")
    else:
        print("\n[FAILURE] Agent failed to process the model.")

if __name__ == "__main__":
    main()
101 changes: 101 additions & 0 deletions docs/agent_design.md
@@ -0,0 +1,101 @@
# GraphNet Automated Sample Extraction Agent: Design Document

## 1. Background
To enrich GraphNet's sample library, we need to automatically download models from Hugging Face (HF) and convert them into subgraph samples that GraphNet can consume. Today this requires hand-writing `run_model.py` for each model, which is slow. This agent automates the pipeline end to end, from "HF model link" to "GraphNet sample submission".

## 2. Core Architecture
The agent uses a modular design built from the following components:

### 2.1 Architecture Diagram
```mermaid
graph TD
    User[User input: HF Model ID] --> Manager[Agent Manager]
    Manager --> Fetcher[Model Fetcher]
    Fetcher -- download model --> Local[Local model files]
    Manager --> Analyzer[Model Analyzer]
    Analyzer -- parse config.json --> Meta[Model metadata: input shape/dtype]
    Manager --> Coder[Code Generator]
    Meta --> Coder
    Coder -- generate code --> Script[run_model.py]
    Manager --> Extractor[Graph Extractor]
    Script --> Extractor
    Extractor -- run & extract graph --> Sample[GraphNet Sample]
    Manager --> Verifier[Sample Verifier]
    Sample --> Verifier
    Verifier -- verification passed --> Git[Git Submitter]
```

### 2.2 Modules

#### 1. Model Fetcher (`agent.fetcher`)
- **Role**: download a model snapshot via `huggingface_hub`.
- **Input**: `model_id` (e.g. `bert-base-uncased`)
- **Output**: local path to the snapshot.
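
`fetcher.py` is not part of this diff. A minimal sketch of what `HFFetcher` could look like, assuming `huggingface_hub.snapshot_download` and the constructor signature used by `core.py` (`HFFetcher(workspace, token=...)`); the `downloads/<org>_<name>` layout is an assumption taken from the path `demo_agent.py` prints:

```python
import os
from typing import Optional

from huggingface_hub import snapshot_download

class HFFetcher:
    """Minimal sketch; the real fetcher.py is not in this diff."""

    def __init__(self, workspace: str, token: Optional[str] = None):
        # Assumption: snapshots live under <workspace>/downloads/<org>_<name>
        self.download_root = os.path.join(workspace, "downloads")
        self.token = token

    def download(self, model_id: str) -> str:
        local_dir = os.path.join(self.download_root, model_id.replace("/", "_"))
        os.makedirs(local_dir, exist_ok=True)
        snapshot_download(repo_id=model_id, local_dir=local_dir, token=self.token)
        return local_dir
```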

#### 2. Model Analyzer (`agent.analyzer`)
- **Role**: parse `config.json` (or `README.md`) in the model directory.
- **Goal**: infer the model's `input_shape` and `input_dtype`. For example, BERT typically takes `input_ids` of shape [batch, seq_len] with dtype int64.
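
For a BERT-style config, the heuristics implemented in `graph_net/agent/analyzer.py` (later in this diff) yield metadata like:

```python
# Metadata inferred for a BERT-style model by ConfigAnalyzer
meta_info = {
    "architecture": "BertModel",
    "input_shape": [1, 128],    # batch size 1, sequence length 128
    "input_dtype": "int64",
    "task_type": "nlp",
    "input_names": ["input_ids", "attention_mask", "token_type_ids"],
}
```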

#### 3. Code Generator (`agent.coder`)
- **Role**: generate `run_model.py`.
- **Strategies**:
  - **Template Mode**: use predefined templates for common architectures (Bert, ResNet, GPT).
  - **LLM Mode (optional)**: call an external LLM API to generate the code (interface reserved; a sketch follows this list).
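
`coder/base.py` and `coder/llm.py` are not included in this diff; a minimal sketch of the shared interface both modes could implement (the `BaseCoder` name is hypothetical, but the `generate` signature matches `TemplateCoder` below):

```python
from abc import ABC, abstractmethod
from typing import Any, Dict

class BaseCoder(ABC):
    """Hypothetical shared interface for TemplateCoder and a future LLMCoder."""

    @abstractmethod
    def generate(self, model_dir: str, meta_info: Dict[str, Any]) -> str:
        """Write a runnable extraction script and return its path."""
```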

#### 4. Graph Extractor (`agent.extractor`)
- **Role**: run the generated `run_model.py` in a subprocess.
- **Dependencies**: reuses `graph_net.torch.run_model`, or invokes the script directly.
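
`extractor.py` is also not in this diff; a minimal sketch, assuming the `SubprocessExtractor(workspace)` constructor and `extract(script_path, model_id)` signature that `core.py` uses:

```python
import os
import subprocess
import sys

class SubprocessExtractor:
    """Minimal sketch; the real extractor.py is not in this diff."""

    def __init__(self, workspace: str):
        self.workspace = workspace

    def extract(self, script_path: str, model_id: str) -> str:
        # The generated script writes into <model_dir>/extracted_sample.
        output_dir = os.path.join(os.path.dirname(script_path), "extracted_sample")
        # Run from the current working directory so `graph_net` stays importable.
        result = subprocess.run([sys.executable, script_path],
                                capture_output=True, text=True)
        if result.returncode != 0:
            raise RuntimeError(f"Extraction failed for {model_id}:\n{result.stderr}")
        return output_dir
```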

#### 5. Sample Verifier (`agent.verifier`)
- **Role**: check that the generated `graph_net.json`, `model.py`, and `input_meta.py` exist and are well-formed.
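
`verifier.py` is not in this diff either; a minimal sketch that checks for the three files named above (searching subdirectories is an assumption, since extraction may emit one directory per subgraph):

```python
import os

class BasicVerifier:
    """Minimal sketch; the real verifier.py is not in this diff."""

    REQUIRED_FILES = ("graph_net.json", "model.py", "input_meta.py")

    def verify(self, output_dir: str) -> bool:
        if not os.path.isdir(output_dir):
            return False
        # Accept the sample if any directory under output_dir holds all files.
        for _root, _dirs, files in os.walk(output_dir):
            if all(name in files for name in self.REQUIRED_FILES):
                return True
        return False
```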

## 3. Interface Design

### The `GraphNetAgent` Class
```python
class GraphNetAgent:
    def __init__(self, workspace: str, hf_token: str = None):
        self.workspace = workspace
        self.fetcher = HFFetcher(token=hf_token)
        self.analyzer = ConfigAnalyzer()
        self.coder = TemplateCoder()
        self.extractor = SubprocessExtractor()
        self.verifier = BasicVerifier()

    def process_model(self, model_id: str) -> bool:
        # 1. Download
        model_dir = self.fetcher.download(model_id)

        # 2. Analyze
        meta_info = self.analyzer.analyze(model_dir)

        # 3. Generate code
        code_path = self.coder.generate(model_dir, meta_info)

        # 4. Extract
        output_dir = self.extractor.extract(code_path)

        # 5. Verify
        return self.verifier.verify(output_dir)
```

## 4. Directory Layout
```text
graph_net/
  agent/
    __init__.py
    core.py          # main agent logic
    fetcher.py       # download module
    analyzer.py      # analysis module
    coder/
      base.py
      template.py    # template-based generation
      llm.py         # LLM-based generation (interface)
    extractor.py     # execution module
    verifier.py      # verification module
```

## 5. Extensibility Plan
- Support more HF task types (NLP, CV, Audio).
- Integrate the DeepSeek/OpenAI APIs to raise code-generation success rates.
- Automate PR submission.
3 changes: 3 additions & 0 deletions graph_net/agent/__init__.py
@@ -0,0 +1,3 @@
from .core import GraphNetAgent

__all__ = ["GraphNetAgent"]
43 changes: 43 additions & 0 deletions graph_net/agent/analyzer.py
@@ -0,0 +1,43 @@
import os
import json
import logging
from typing import Dict, Any

class ConfigAnalyzer:
    def __init__(self):
        self.logger = logging.getLogger("ConfigAnalyzer")

    def analyze(self, model_dir: str) -> Dict[str, Any]:
        """
        Analyze config.json to infer input specifications.
        """
        config_path = os.path.join(model_dir, "config.json")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"config.json not found in {model_dir}")

        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)

        # Guard against a missing or empty "architectures" field
        architecture = (config.get("architectures") or ["Unknown"])[0]
        self.logger.info(f"Detected architecture: {architecture}")

        # Heuristic defaults: NLP model, batch size 1, sequence length 128
        meta_info = {
            "architecture": architecture,
            "input_shape": [1, 128],
            "input_dtype": "int64",
            "task_type": "nlp",
        }

        # Refine based on architecture
        if "Bert" in architecture or "Roberta" in architecture:
            meta_info["input_names"] = ["input_ids", "attention_mask", "token_type_ids"]
        elif "GPT" in architecture or "Llama" in architecture:
            # HF config names use e.g. "GPT2LMHeadModel", so match "GPT", not "Gpt"
            meta_info["input_names"] = ["input_ids", "attention_mask"]
        elif "ResNet" in architecture or "ViT" in architecture:
            meta_info["task_type"] = "cv"
            meta_info["input_shape"] = [1, 3, 224, 224]
            meta_info["input_dtype"] = "float32"
            meta_info["input_names"] = ["pixel_values"]

        return meta_info
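
A quick, illustrative use of the analyzer above; the model path is hypothetical and mirrors the workspace layout that `demo_agent.py` prints:

```python
import logging

from graph_net.agent.analyzer import ConfigAnalyzer

logging.basicConfig(level=logging.INFO)

meta = ConfigAnalyzer().analyze("agent_workspace/downloads/prajjwal1_bert-tiny")
print(meta["architecture"], meta["input_shape"], meta["input_names"])
```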
1 change: 1 addition & 0 deletions graph_net/agent/coder/__init__.py
@@ -0,0 +1 @@
from .template import TemplateCoder
96 changes: 96 additions & 0 deletions graph_net/agent/coder/template.py
@@ -0,0 +1,96 @@
import os
import logging
from typing import Dict, Any

class TemplateCoder:
    def __init__(self):
        self.logger = logging.getLogger("TemplateCoder")

    def generate(self, model_dir: str, meta_info: Dict[str, Any]) -> str:
        """
        Generate a python script to load the model and run extraction.
        """
        script_content = self._create_script_content(model_dir, meta_info)

        output_path = os.path.join(model_dir, "run_extraction.py")
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(script_content)

        return output_path

    def _create_script_content(self, model_dir: str, meta_info: Dict[str, Any]) -> str:
        # Basic template for HF models
        input_names = meta_info.get("input_names", ["input_ids"])  # reserved for keyword-based inputs
        input_shape = meta_info.get("input_shape", [1, 128])
        input_dtype = meta_info.get("input_dtype", "int64")

        # Construct the input-generation snippet. Each line carries a 4-space
        # indent so it lines up with the body of main() in the template below.
        input_gen_code = ""
        if meta_info["task_type"] == "nlp":
            input_gen_code += f"""
    # NLP inputs
    input_ids = torch.randint(0, 100, {tuple(input_shape)}, dtype=torch.{input_dtype})
    attention_mask = torch.ones({tuple(input_shape)}, dtype=torch.{input_dtype})
    inputs = (input_ids, attention_mask)
"""
        elif meta_info["task_type"] == "cv":
            input_gen_code += f"""
    # CV inputs
    inputs = (torch.randn({tuple(input_shape)}, dtype=torch.{input_dtype}),)
"""

        template = f"""
import sys
import os
import torch
from transformers import AutoModel

# Ensure graph_net is in path
sys.path.append(os.getcwd())

def main():
    model_path = r"{model_dir}"
    output_dir = os.path.join(model_path, "extracted_sample")

    print(f"Loading model from {{model_path}}...")
    try:
        model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
        model.eval()
    except Exception as e:
        print(f"Failed to load model: {{e}}")
        sys.exit(1)

    print("Generating inputs...")
{input_gen_code}
    # Move to CUDA if available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    inputs = tuple(t.to(device) for t in inputs)

    print("Starting extraction...")
    # Point GraphNet at the output directory via its workspace variable
    os.environ['GRAPH_NET_EXTRACT_WORKSPACE'] = output_dir

    # Use the extract API from graph_net:
    # extract(name, dynamic=True)(model) returns a compiled model; running
    # it once triggers compilation and graph extraction.
    from graph_net.torch.extractor import extract

    compiled_model = extract(name="subgraph", dynamic=True)(model)

    print("Running forward pass to trigger extraction...")
    with torch.no_grad():
        if isinstance(inputs, tuple):
            compiled_model(*inputs)
        elif isinstance(inputs, dict):
            compiled_model(**inputs)
        else:
            compiled_model(inputs)

    print(f"Extraction complete. Results in {{output_dir}}")

if __name__ == "__main__":
    main()
"""
        return template
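
An illustrative pairing of the analyzer and the coder above; the path is hypothetical:

```python
from graph_net.agent.analyzer import ConfigAnalyzer
from graph_net.agent.coder.template import TemplateCoder

model_dir = "agent_workspace/downloads/prajjwal1_bert-tiny"
meta_info = ConfigAnalyzer().analyze(model_dir)
script_path = TemplateCoder().generate(model_dir, meta_info)
print(f"Extraction script written to {script_path}")  # <model_dir>/run_extraction.py
```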
81 changes: 81 additions & 0 deletions graph_net/agent/core.py
@@ -0,0 +1,81 @@
import os
import logging
import traceback
from typing import Optional

from .fetcher import HFFetcher
from .analyzer import ConfigAnalyzer
from .coder.template import TemplateCoder
from .extractor import SubprocessExtractor
from .verifier import BasicVerifier

class GraphNetAgent:
    def __init__(self, workspace: str, hf_token: Optional[str] = None):
        """
        Initialize the GraphNet Agent.

        Args:
            workspace (str): Directory where models and samples will be stored.
            hf_token (str, optional): Hugging Face API token.
        """
        self.workspace = os.path.abspath(workspace)
        os.makedirs(self.workspace, exist_ok=True)

        self.logger = logging.getLogger("GraphNetAgent")
        self.logger.setLevel(logging.INFO)

        # Initialize components
        self.fetcher = HFFetcher(self.workspace, token=hf_token)
        self.analyzer = ConfigAnalyzer()
        self.coder = TemplateCoder()
        self.extractor = SubprocessExtractor(self.workspace)
        self.verifier = BasicVerifier()

    def process_model(self, model_id: str) -> bool:
        """
        Process a single model: Download -> Analyze -> Generate Code -> Extract -> Verify.

        Args:
            model_id (str): Hugging Face model ID (e.g. 'bert-base-uncased')

        Returns:
            bool: True if successful, False otherwise.
        """
        self.logger.info(f"Starting process for model: {model_id}")

        try:
            # 1. Download model
            self.logger.info("Step 1: Downloading model...")
            model_dir = self.fetcher.download(model_id)
            self.logger.info(f"Model downloaded to: {model_dir}")

            # 2. Analyze model config
            self.logger.info("Step 2: Analyzing model config...")
            meta_info = self.analyzer.analyze(model_dir)
            self.logger.info(f"Analysis result: {meta_info}")

            # 3. Generate runner script
            self.logger.info("Step 3: Generating run_model.py...")
            script_path = self.coder.generate(model_dir, meta_info)
            self.logger.info(f"Script generated at: {script_path}")

            # 4. Extract subgraph
            self.logger.info("Step 4: Extracting subgraph...")
            output_dir = self.extractor.extract(script_path, model_id)
            self.logger.info(f"Extraction output dir: {output_dir}")

            # 5. Verify result
            self.logger.info("Step 5: Verifying result...")
            is_valid = self.verifier.verify(output_dir)

            if is_valid:
                self.logger.info(f"SUCCESS: Model {model_id} processed successfully.")
                return True
            else:
                self.logger.error(f"FAILURE: Verification failed for {model_id}.")
                return False

        except Exception as e:
            self.logger.error(f"Error processing {model_id}: {e}")
            self.logger.error(traceback.format_exc())
            return False