Commit: add the SGFormer model and the Actor and DeezerEurope datasets (model: sgformer; datasets: actor, deezer_europe).

Co-authored-by: Xingyuan Ji <82867498+xy-Ji@users.noreply.github.com>
Showing 8 changed files with 480 additions and 2 deletions.
New file (+176 lines): sgformer_trainer.py
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
@File    :   sgformer_trainer.py
@Time    :   2024/12/30 12:57:55
@Author  :   Cui Shanyuan
"""

import os
import argparse
import tensorlayerx as tlx
import numpy as np
from gammagl.datasets import Planetoid, WikipediaNetwork, Actor, DeezerEurope
from gammagl.models import SGFormerModel
from gammagl.utils import add_self_loops, mask_to_index
from tensorlayerx.model import TrainOneStep, WithLoss

class SemiSpvzLoss(WithLoss):
    """Semi-supervised loss: cross-entropy evaluated on the training nodes only."""

    def __init__(self, net, loss_fn):
        super(SemiSpvzLoss, self).__init__(backbone=net, loss_fn=loss_fn)

    def forward(self, data, y):
        logits = self.backbone_network(data['x'], data['edge_index'], None, data['num_nodes'])
        train_logits = tlx.gather(logits, data['train_idx'])
        train_y = tlx.gather(data['y'], data['train_idx'])
        loss = self._loss_fn(train_logits, train_y)
        return loss


def calculate_acc(logits, y, metrics):
    """Compute accuracy for one evaluation pass, resetting the metric state afterwards."""
    metrics.update(logits, y)
    rst = metrics.result()
    metrics.reset()
    return rst

def main(args):
    # Load the dataset and build train/val/test index splits.
    if args.dataset.lower() in ['cora', 'pubmed', 'citeseer']:
        dataset = Planetoid(args.dataset_path, args.dataset)
        graph = dataset[0]
        train_idx = mask_to_index(graph.train_mask)
        test_idx = mask_to_index(graph.test_mask)
        val_idx = mask_to_index(graph.val_mask)
    elif args.dataset.lower() in ['chameleon', 'squirrel']:
        dataset = WikipediaNetwork(args.dataset_path, args.dataset)
        graph = dataset[0]

        # Use the first geom-gcn 60/20/20 split shipped with the raw data.
        split_idx = 0
        current_dir = os.path.dirname(os.path.abspath(__file__))
        split_path = os.path.join(current_dir, args.dataset, 'geom_gcn', 'raw',
                                  f'{args.dataset}_split_0.6_0.2_{split_idx}.npz')
        print(f"Looking for split file at: {split_path}")
        splits_file = np.load(split_path)
        train_mask = splits_file['train_mask']
        val_mask = splits_file['val_mask']
        test_mask = splits_file['test_mask']
        train_idx = np.where(train_mask)[0]
        val_idx = np.where(val_mask)[0]
        test_idx = np.where(test_mask)[0]
    elif args.dataset.lower() == 'actor':
        dataset = Actor(args.dataset_path)
        graph = dataset[0]

        # Actor ships ten pre-computed splits; select one column by index.
        split_idx = args.split_idx
        train_idx = mask_to_index(graph.train_mask[:, split_idx])
        val_idx = mask_to_index(graph.val_mask[:, split_idx])
        test_idx = mask_to_index(graph.test_mask[:, split_idx])
    elif args.dataset.lower() == 'deezer':
        dataset = DeezerEurope(args.dataset_path)
        graph = dataset[0]

        # DeezerEurope has no canonical split, so draw a random 60/20/20 one.
        num_nodes = graph.num_nodes
        train_ratio = 0.6
        val_ratio = 0.2

        indices = np.random.permutation(num_nodes)
        train_size = int(num_nodes * train_ratio)
        val_size = int(num_nodes * val_ratio)

        train_idx = indices[:train_size]
        val_idx = indices[train_size:train_size + val_size]
        test_idx = indices[train_size + val_size:]
    else:
        raise ValueError('Unknown dataset: {}'.format(args.dataset))

    # Add self-loops once up front; the model consumes the augmented edge_index.
    edge_index, _ = add_self_loops(graph.edge_index, num_nodes=graph.num_nodes)

    net = SGFormerModel(feature_dim=dataset.num_node_features,
                        hidden_dim=args.hidden_dim,
                        num_class=dataset.num_classes,
                        trans_num_layers=args.trans_num_layers,
                        trans_num_heads=args.trans_num_heads,
                        trans_dropout=args.trans_dropout,
                        gnn_num_layers=args.gnn_num_layers,
                        gnn_dropout=args.gnn_dropout,
                        graph_weight=args.graph_weight,
                        name="SGFormer")

    optimizer = tlx.optimizers.Adam(lr=args.lr, weight_decay=args.l2_coef)
    metrics = tlx.metrics.Accuracy()
    train_weights = net.trainable_weights

    loss_func = SemiSpvzLoss(net, tlx.losses.softmax_cross_entropy_with_logits)
    train_one_step = TrainOneStep(loss_func, optimizer, train_weights)

    data = {
        "x": tlx.convert_to_tensor(graph.x),
        "y": tlx.convert_to_tensor(graph.y),
        "edge_index": tlx.convert_to_tensor(edge_index),
        "train_idx": tlx.convert_to_tensor(train_idx),
        "test_idx": tlx.convert_to_tensor(test_idx),
        "val_idx": tlx.convert_to_tensor(val_idx),
        "num_nodes": graph.num_nodes,
    }

    best_val_acc = 0
    for epoch in range(args.n_epoch):
        net.set_train()
        train_loss = train_one_step(data, graph.y)
        net.set_eval()
        logits = net(data['x'], data['edge_index'], None, data['num_nodes'])
        val_logits = tlx.gather(logits, data['val_idx'])
        val_y = tlx.gather(data['y'], data['val_idx'])
        val_acc = calculate_acc(val_logits, val_y, metrics)

        print("Epoch [{:0>3d}]  train loss: {:.4f}  val acc: {:.4f}".format(
            epoch + 1, train_loss.item(), val_acc))

        # Checkpoint the weights whenever validation accuracy improves.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            net.save_weights(args.best_model_path + net.name + ".npz", format='npz_dict')

    # Restore the best checkpoint and report test accuracy.
    net.load_weights(args.best_model_path + net.name + ".npz", format='npz_dict')
    net.set_eval()
    logits = net(data['x'], data['edge_index'], None, data['num_nodes'])
    test_logits = tlx.gather(logits, data['test_idx'])
    test_y = tlx.gather(data['y'], data['test_idx'])
    test_acc = calculate_acc(test_logits, test_y, metrics)
    print("Test acc: {:.4f}".format(test_acc))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--lr', type=float, default=0.01, help='learning rate')
    parser.add_argument('--n_epoch', type=int, default=200, help='number of epochs')
    parser.add_argument('--hidden_dim', type=int, default=64, help='hidden dimension')
    parser.add_argument('--trans_num_layers', type=int, default=1, help='number of transformer layers')
    parser.add_argument('--trans_num_heads', type=int, default=1, help='number of attention heads')
    parser.add_argument('--trans_dropout', type=float, default=0.5, help='transformer dropout rate')
    parser.add_argument('--gnn_num_layers', type=int, default=2, help='number of GNN layers')
    parser.add_argument('--gnn_dropout', type=float, default=0.5, help='GNN dropout rate')
    parser.add_argument('--graph_weight', type=float, default=0.8, help='weight for the GNN branch')
    parser.add_argument('--l2_coef', type=float, default=5e-4, help='l2 loss coefficient')
    parser.add_argument('--dataset', type=str, default='cora',
                        choices=['cora', 'pubmed', 'citeseer', 'chameleon',
                                 'squirrel', 'actor', 'deezer'],
                        help='dataset name')
    parser.add_argument('--dataset_path', type=str, default=r'', help='path to save dataset')
    parser.add_argument('--best_model_path', type=str, default=r'./', help='path to save best model')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--split_idx', type=int, default=0,
                        help='split index for the actor dataset')
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    args = parser.parse_args()

    # Seed NumPy so the random deezer split is reproducible.
    np.random.seed(args.seed)

    if args.gpu >= 0:
        tlx.set_device("cuda", args.gpu)
    else:
        tlx.set_device("CPU")

    main(args)
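
The argparse block above is enough to exercise both new dataset branches end to end. A minimal invocation sketch (the --dataset_path value below is an arbitrary example location, not one fixed by the script):

    python sgformer_trainer.py --dataset actor --split_idx 0 --dataset_path ./data
    python sgformer_trainer.py --dataset deezer --seed 42 --dataset_path ./data

For actor, --split_idx selects one of the ten pre-computed geom-gcn splits; for deezer, the 60/20/20 split is drawn randomly, so --seed determines it.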
New file (+116 lines): the Actor dataset (imported above as gammagl.datasets.Actor)
from typing import Callable, List, Optional

import numpy as np
import tensorlayerx as tlx
from gammagl.data import InMemoryDataset, download_url, Graph
from gammagl.utils import coalesce

class Actor(InMemoryDataset):
    r"""The actor-only induced subgraph of the film-director-actor-writer
    network used in the
    `"Geom-GCN: Geometric Graph Convolutional Networks"
    <https://openreview.net/forum?id=S1e2agrFvS>`_ paper.
    Each node corresponds to an actor, and an edge between two nodes denotes
    co-occurrence on the same Wikipedia page.
    Node features correspond to keywords found in the Wikipedia pages.
    The task is to classify the nodes into five categories based on the words
    in the actors' Wikipedia pages.

    Parameters
    ----------
    root: str
        Root directory where the dataset should be saved.
    transform: callable, optional
        A function/transform that takes in a
        :obj:`gammagl.data.Graph` object and returns a transformed
        version. The data object will be transformed before every access.
        (default: :obj:`None`)
    pre_transform: callable, optional
        A function/transform that takes in an
        :obj:`gammagl.data.Graph` object and returns a
        transformed version. The data object will be transformed before
        being saved to disk. (default: :obj:`None`)
    force_reload: bool, optional
        Whether to re-process the dataset. (default: :obj:`False`)

    **STATS:**

    .. list-table::
        :widths: 10 10 10 10
        :header-rows: 1

        * - #nodes
          - #edges
          - #features
          - #classes
        * - 7,600
          - 30,019
          - 932
          - 5
    """

    url = 'https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/master'

    def __init__(self, root: str, transform: Optional[Callable] = None,
                 pre_transform: Optional[Callable] = None,
                 force_reload: bool = False) -> None:
        super().__init__(root, transform, pre_transform, force_reload=force_reload)
        self.data, self.slices = self.load_data(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
        return (['out1_node_feature_label.txt', 'out1_graph_edges.txt'] +
                [f'film_split_0.6_0.2_{i}.npz' for i in range(10)])

    @property
    def processed_file_names(self) -> str:
        return tlx.BACKEND + '_data.pt'

    def download(self) -> None:
        # Node/edge files live under new_data/film; the ten splits under splits/.
        for f in self.raw_file_names[:2]:
            download_url(f'{self.url}/new_data/film/{f}', self.raw_dir)
        for f in self.raw_file_names[2:]:
            download_url(f'{self.url}/splits/{f}', self.raw_dir)

    def process(self) -> None:
        # Each node line is "<id>\t<comma-separated feature indices>\t<label>".
        with open(self.raw_paths[0], 'r') as f:
            node_data = [x.split('\t') for x in f.read().split('\n')[1:-1]]

        # Build a dense multi-hot feature matrix from the sparse keyword indices.
        rows, cols = [], []
        for n_id, line, _ in node_data:
            indices = [int(x) for x in line.split(',')]
            rows += [int(n_id)] * len(indices)
            cols += indices
        row = np.array(rows, dtype=np.int64)
        col = np.array(cols, dtype=np.int64)

        num_nodes = int(row.max()) + 1
        num_features = int(col.max()) + 1
        x = np.zeros((num_nodes, num_features), dtype=np.float32)
        x[row, col] = 1.0

        y = np.zeros(len(node_data), dtype=np.int64)
        for n_id, _, label in node_data:
            y[int(n_id)] = int(label)

        with open(self.raw_paths[1], 'r') as f:
            edge_data = f.read().split('\n')[1:-1]
            edge_indices = [[int(v) for v in r.split('\t')] for r in edge_data]
            edge_index = np.array(edge_indices, dtype=np.int64).T
            edge_index = coalesce(edge_index)  # deduplicate edges; self-loops are kept

        # Stack the ten pre-computed splits along axis 1, one column per split.
        train_masks, val_masks, test_masks = [], [], []
        for path in self.raw_paths[2:]:
            tmp = np.load(path)
            train_masks.append(tmp['train_mask'].astype(np.bool_))
            val_masks.append(tmp['val_mask'].astype(np.bool_))
            test_masks.append(tmp['test_mask'].astype(np.bool_))
        train_mask = np.stack(train_masks, axis=1)
        val_mask = np.stack(val_masks, axis=1)
        test_mask = np.stack(test_masks, axis=1)

        data = Graph(x=tlx.convert_to_tensor(x), edge_index=tlx.convert_to_tensor(edge_index),
                     y=tlx.convert_to_tensor(y), train_mask=tlx.convert_to_tensor(train_mask),
                     val_mask=tlx.convert_to_tensor(val_mask), test_mask=tlx.convert_to_tensor(test_mask))

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        self.save_data(self.collate([data]), self.processed_paths[0])
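
A short usage sketch for the new class (assuming an arbitrary local root directory; the column indexing mirrors what sgformer_trainer.py does with graph.train_mask):

    from gammagl.datasets import Actor
    from gammagl.utils import mask_to_index

    dataset = Actor(root='./data/actor')  # downloads and processes the raw files on first use
    graph = dataset[0]
    # The ten pre-computed 60/20/20 splits are stacked along axis 1, one column per split.
    train_idx = mask_to_index(graph.train_mask[:, 0])
    print(graph.num_nodes, dataset.num_classes)  # per the stats table: 7600, 5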
New file (+38 lines): the DeezerEurope dataset (imported above as gammagl.datasets.DeezerEurope)
from typing import Callable, Optional

import numpy as np
import tensorlayerx as tlx
from gammagl.data import InMemoryDataset, download_url, Graph

class DeezerEurope(InMemoryDataset):
    # Social network of European Deezer users, hosted at graphmining.ai.
    url = 'https://graphmining.ai/datasets/ptg/deezer_europe.npz'

    def __init__(self, root: str, transform: Optional[Callable] = None,
                 pre_transform: Optional[Callable] = None,
                 pre_filter: Optional[Callable] = None) -> None:
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = self.load_data(self.processed_paths[0])

    @property
    def raw_file_names(self) -> str:
        return 'deezer_europe.npz'

    @property
    def processed_file_names(self) -> str:
        return tlx.BACKEND + '_data.npy'

    def download(self) -> None:
        download_url(self.url, self.raw_dir)

    def process(self) -> None:
        data = np.load(self.raw_paths[0], allow_pickle=True)
        x = data['features'].astype(np.float32)
        # Targets are class indices; keep them integral so that
        # softmax_cross_entropy_with_logits and num_classes work downstream.
        y = data['target'].astype(np.int64)
        edge_index = data['edges'].astype(np.int64)
        edge_index = edge_index.T

        # Convert to backend tensors, mirroring the Actor dataset above.
        data = Graph(x=tlx.convert_to_tensor(x), y=tlx.convert_to_tensor(y),
                     edge_index=tlx.convert_to_tensor(edge_index))

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        self.save_data(self.collate([data]), self.processed_paths[0])
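
A matching sketch for DeezerEurope (again assuming an arbitrary root; since no canonical split ships with the data, the random 60/20/20 split below mirrors the trainer's deezer branch):

    import numpy as np
    from gammagl.datasets import DeezerEurope

    dataset = DeezerEurope(root='./data/deezer')
    graph = dataset[0]
    perm = np.random.permutation(graph.num_nodes)
    n_train, n_val = int(0.6 * graph.num_nodes), int(0.2 * graph.num_nodes)
    train_idx = perm[:n_train]               # 60% train
    val_idx = perm[n_train:n_train + n_val]  # 20% val
    test_idx = perm[n_train + n_val:]        # 20% test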