-
Notifications
You must be signed in to change notification settings - Fork 78
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #202 from 1920309095/Hid_Net
[Model] Add HiD-Net
- Loading branch information
Showing
7 changed files
with
384 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
import argparse | ||
import os | ||
# os.environ['CUDA_VISIBLE_DEVICES'] = '0' | ||
# os.environ['TL_BACKEND'] = 'torch' | ||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' | ||
# 0:Output all; 1:Filter out INFO; 2:Filter out INFO and WARNING; 3:Filter out INFO, WARNING, and ERROR | ||
|
||
import sys | ||
import argparse | ||
import gammagl.transforms as T | ||
import tensorlayerx as tlx | ||
from gammagl.datasets import Planetoid | ||
from tensorlayerx.model import TrainOneStep, WithLoss | ||
import argparse | ||
import numpy as np | ||
import warnings | ||
import sys | ||
import argparse | ||
from gammagl.models import Hid_net | ||
from gammagl.utils import mask_to_index | ||
warnings.filterwarnings('ignore') | ||
|
||
|
||
class SemiSpvzLoss(WithLoss): | ||
def __init__(self, net, loss_fn): | ||
super(SemiSpvzLoss, self).__init__(backbone=net, loss_fn=loss_fn) | ||
|
||
def forward(self, data, graph): | ||
logits = self.backbone_network(data['x'], data['edge_index'],num_nodes=data['num_nodes']) | ||
train_logits = tlx.gather(logits, data['train_mask']) | ||
train_y = tlx.gather(data['y'], data['train_mask']) | ||
|
||
loss = self._loss_fn(train_logits,train_y) | ||
|
||
return loss | ||
|
||
def calculate_acc(logits, y, metrics): | ||
""" | ||
Args: | ||
logits: node logits | ||
y: node labels | ||
metrics: tensorlayerx.metrics | ||
Returns: | ||
rst | ||
""" | ||
|
||
metrics.update(logits, y) | ||
rst = metrics.result() | ||
metrics.reset() | ||
return rst | ||
|
||
def main(args): | ||
|
||
# load datasets | ||
if str.lower(args.dataset) not in ['cora','pubmed','citeseer']: | ||
raise ValueError('Unknown dataset: {}'.format(args.dataset)) | ||
dataset = Planetoid(args.dataset_path, args.dataset) | ||
graph = dataset[0] | ||
edge_weight = tlx.ones(shape=(graph.edge_index.shape[1], 1)) | ||
|
||
# for mindspore, it should be passed into node indices | ||
train_idx = mask_to_index(graph.train_mask) | ||
test_idx = mask_to_index(graph.test_mask) | ||
val_idx = mask_to_index(graph.val_mask) | ||
|
||
data = { | ||
"x": graph.x, | ||
"y": graph.y, | ||
"edge_index": graph.edge_index, | ||
"edge_weight": edge_weight, | ||
"train_mask": train_idx, | ||
"test_mask": test_idx, | ||
"val_mask": val_idx, | ||
"num_nodes": graph.num_nodes, | ||
} | ||
|
||
model = Hid_net(in_feats=dataset.num_features, | ||
hidden_dim=args.hidden_dim, | ||
n_classes=dataset.num_classes, | ||
num_layers=args.num_layers, | ||
alpha=args.alpha, | ||
beta=args.beta, | ||
gamma=args.gamma, | ||
add_bias=args.add_bias, | ||
normalize=args.normalize, | ||
drop_rate=args.drop_rate, | ||
sigma1=args.sigma1, | ||
sigma2=args.sigma2, | ||
name="Hid_Net") | ||
|
||
optimizer = tlx.optimizers.Adam(lr=args.lr, weight_decay=args.weight_decay) | ||
metrics = tlx.metrics.Accuracy() | ||
train_weights = model.trainable_weights | ||
|
||
loss_func = SemiSpvzLoss(model, tlx.losses.softmax_cross_entropy_with_logits) | ||
train_one_step = TrainOneStep(loss_func, optimizer, train_weights) | ||
|
||
best_val_acc=0 | ||
for epoch in range(args.n_epoch): | ||
model.set_train() | ||
train_loss = train_one_step(data, graph.y) | ||
|
||
model.set_eval() | ||
logits = model(data['x'], data['edge_index'], data['edge_weight'], num_nodes=data['num_nodes']) | ||
|
||
val_preds=tlx.gather(logits, data['val_mask']) | ||
val_y=tlx.gather(data['y'], data['val_mask']) | ||
val_acc=calculate_acc(val_preds,val_y,metrics) | ||
|
||
# save best model on evaluation set | ||
if val_acc > best_val_acc: | ||
best_val_acc = val_acc | ||
model.save_weights(args.best_model_path+model.name+".npz", format='npz_dict') | ||
|
||
print("Epoch [{:0>3d}] ".format(epoch+1)\ | ||
+ " train loss: {:.4f}".format(train_loss.item())\ | ||
+ " val acc: {:.4f}".format(val_acc)) | ||
|
||
model.load_weights(args.best_model_path+model.name+".npz", format='npz_dict') | ||
model.set_eval() | ||
logits = model(data['x'], data['edge_index'], data['edge_weight'], num_nodes=data['num_nodes']) | ||
test_preds = tlx.gather(logits, data['test_mask']) | ||
test_y = tlx.gather(data['y'], data['test_mask']) | ||
test_acc = calculate_acc(test_preds, test_y, metrics) | ||
print("Test acc: {:.4f}".format(test_acc)) | ||
|
||
if __name__=='__main__': | ||
# parameters settings | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--best_model_path", type=str, default=r'./', help="path to save best model") | ||
parser.add_argument('--n_epoch', type=int, default=150, help='the num of epoch') | ||
parser.add_argument('--lr', type=float, default=0.01, help='learning rate') | ||
parser.add_argument('--weight_decay', type=float, default=0.00, help='weight decay (L2 loss on parameters)') | ||
parser.add_argument('--hidden_dim', type=int, default=128, help='hidden size') | ||
parser.add_argument('--drop_rate', type=float, default=0.55, help='dropout rate') | ||
parser.add_argument('--dataset', type=str, default='cora', choices=['cora', 'citeseer', 'pubmed']) | ||
parser.add_argument('--dataset_path', type=str, default='./data', help='path to save dataset') | ||
parser.add_argument('--num_layers', type=int, default=10, help='num_layers') | ||
parser.add_argument('--alpha', type=float, default=0.1, help='tolerance to stop EM algorithm') | ||
parser.add_argument('--beta', type=float, default=0.9, help='tolerance to stop EM algorithm') | ||
parser.add_argument('--gamma', type=float, default=0.3, help='tolerance to stop EM algorithm') | ||
parser.add_argument('--sigma1', type=float, default=0.5, help='tolerance to stop EM algorithm') | ||
parser.add_argument('--sigma2', type=float, default=0.5, help='tolerance to stop EM algorithm') | ||
parser.add_argument('--add_bias', type=bool, default=True, help='if tune') | ||
parser.add_argument('--normalize', type=bool, default=True, help='if tune') | ||
parser.add_argument('--gpu', default='-1', type=int, help='-1 means cpu') | ||
args = parser.parse_args() | ||
|
||
if args.gpu >= 0: | ||
tlx.set_device("GPU", args.gpu) | ||
else: | ||
tlx.set_device("CPU") | ||
|
||
main(args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
Representation Learning on Graphs with Jumping Knowledge Networks (JK-Net) | ||
============ | ||
|
||
- Paper link: [https://arxiv.org/abs/2312.08616](https://arxiv.org/abs/2312.08616) | ||
- Author's code repo: [https://github.com/BUPT-GAMMA/HiD-Net](https://github.com/BUPT-GAMMA/HiD-Net). Note that the original code is | ||
implemented with Torch for the paper. | ||
|
||
|
||
How to run | ||
---------- | ||
Run with following (available dataset: "cora", "citeseer", "pubmed") | ||
```bash | ||
python hid_trainer.py --dataset cora | ||
``` | ||
|
||
|
||
Results | ||
------- | ||
```bash | ||
TL_BACKEND="mindspore" python hid_trainer.py --dataset cora --alpha 0.1 --beta 0.9 --gamma 0.3 --k 10 --hidden 128 --lr 0.01 --weight_decay 0 --dropout 0.55 | ||
TL_BACKEND="mindspore" python hid_trainer.py --dataset citeseer --alpha 0.1 --beta 0.9 --gamma 0.2 --k 10 --hidden 64 --lr 0.005 --weight_decay 0.05 --dropout 0.5 | ||
TL_BACKEND="mindspore" python hid_trainer.py --dataset pubmed --alpha 0.08 --beta 0.92 --gamma 0.3 --k 8 --hidden 32 --lr 0.02 --weight_decay 0.0005 --dropout 0.5 | ||
TL_BACKEND="paddle" python hid_trainer.py --dataset cora --alpha 0.1 --beta 0.9 --gamma 0.3 --k 10 --hidden 128 --lr 0.01 --weight_decay 0 --dropout 0.55 | ||
TL_BACKEND="paddle" python hid_trainer.py --dataset citeseer --alpha 0.1 --beta 0.9 --gamma 0.2 --k 10 --hidden 64 --lr 0.005 --weight_decay 0.05 --dropout 0.5 | ||
TL_BACKEND="paddle" python hid_trainer.py --dataset pubmed --alpha 0.08 --beta 0.92 --gamma 0.3 --k 8 --hidden 32 --lr 0.02 --weight_decay 0.0005 --dropout 0.5 | ||
TL_BACKEND="tensorflow" python hid_trainer.py --dataset cora --alpha 0.1 --beta 0.9 --gamma 0.3 --k 10 --hidden 128 --lr 0.01 --weight_decay 0 --dropout 0.55 | ||
TL_BACKEND="tensorflow" python hid_trainer.py --dataset citeseer --alpha 0.1 --beta 0.9 --gamma 0.2 --k 10 --hidden 64 --lr 0.005 --weight_decay 0.05 --dropout 0.5 | ||
TL_BACKEND="tensorflow" python hid_trainer.py --dataset pubmed --alpha 0.08 --beta 0.92 --gamma 0.3 --k 8 --hidden 32 --lr 0.02 --weight_decay 0.0005 --dropout 0.5 | ||
TL_BACKEND="torch" python hid_trainer.py --dataset cora --alpha 0.1 --beta 0.9 --gamma 0.3 --k 10 --hidden 128 --lr 0.01 --weight_decay 0 --dropout 0.55 | ||
TL_BACKEND="torch" python hid_trainer.py --dataset citeseer --alpha 0.1 --beta 0.9 --gamma 0.2 --k 10 --hidden 64 --lr 0.005 --weight_decay 0.05 --dropout 0.5 | ||
TL_BACKEND="torch" python hid_trainer.py --dataset pubmed --alpha 0.08 --beta 0.92 --gamma 0.3 --k 8 --hidden 32 --lr 0.02 --weight_decay 0.0005 --dropout 0.5 | ||
``` | ||
| Dataset | Paper | Our(ms) | Our(pd) | Our(tf) | Our(th) | | ||
| ---- | ---- | --- | ---- | ---- | ---- | | ||
| cora | 0.840(±0.6) | 0.8078(±0.0049) | 0.8274(±0.0190) | 0.8218(±0.0024) | 0.8138(±0.00024) | | ||
| citeseer | 0.732(±0.2) | 0.7138(±0.0019) | 0.7134(±0.0012) | 0.7140(±0.0022)| 0.7134(±0.0022)| | ||
| pubmed | 0.811(±0.1) | 0.7996(±0.0030) | 0.7910(±0.0044) | 0.8026(±0.0034) | 0.7938(±0.0151) | 0.7920(±0.0097)| |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import tensorlayerx as tlx | ||
from gammagl.layers.conv import MessagePassing | ||
import tensorlayerx as tlx | ||
from gammagl.mpops import * | ||
from gammagl.utils.num_nodes import maybe_num_nodes | ||
|
||
|
||
def cal_g_gradient(edge_index, x, edge_weight=None, sigma1=0.5, sigma2=0.5, num_nodes=None, | ||
dtype=None): | ||
row, col = edge_index[0], edge_index[1] | ||
ones = tlx.ones([edge_index[0].shape[0]]) | ||
if dtype is not None: | ||
ones = tlx.cast(ones, dtype) | ||
if num_nodes is None: | ||
num_nodes = int(1 + tlx.reduce_max(edge_index[0])) | ||
deg = unsorted_segment_sum(ones, col, num_segments=num_nodes) | ||
deg_inv = tlx.pow(deg+1e-8, -1) | ||
deg_in_row = tlx.reshape(tlx.gather(deg_inv, row), (-1,1)) | ||
x_row = tlx.gather(x, row) | ||
x_col = tlx.gather(x, col) | ||
gra = deg_in_row * (x_col - x_row) | ||
avg_gra = unsorted_segment_sum(gra, row, num_segments=x.shape[0]) | ||
dx = x_row - x_col | ||
norms_dx = tlx.reduce_sum(tlx.square(dx), axis=1) | ||
norms_dx = tlx.sqrt(norms_dx) | ||
s = norms_dx | ||
s = tlx.exp(- (s * s) / (2 * sigma1 * sigma2)) | ||
r = unsorted_segment_sum(tlx.reshape(s,(-1,1)), row,num_segments=x.shape[0]) | ||
r_row = tlx.gather(r, row) | ||
coe = tlx.reshape(s, (-1,1)) / (r_row + 1e-6) | ||
avg_gra_row = tlx.gather(avg_gra, row) | ||
result = unsorted_segment_sum(avg_gra_row * coe, col, num_segments=x.shape[0]) | ||
return result | ||
|
||
class Hid_conv(MessagePassing): | ||
r'''The proposed high-order graph diffusion equation is given by: | ||
.. math:: | ||
\frac{\partial x(t)_i}{\partial t} = | ||
\alpha(x(0)_i - x(t)_i) + | ||
\beta \text{div}(f(\nabla x(t)_{ij})) + | ||
\gamma(\nabla x(t)_j), | ||
where \( \alpha \), \( \beta \), and \( \gamma \) are parameters of the model. | ||
This equation integrates the high-order diffusion process by considering the influence of both first-order and second-order neighbors in the graph. | ||
The iteration step based on this equation is formulated as: | ||
.. math:: | ||
x(t+\Delta t)_i = \\alpha \Delta t x(0)_i + | ||
(1 - \alpha \Delta t)x(t)_i + \beta \Delta t \text{div}(f(\nabla x(t)_i)) + | ||
\beta \gamma \Delta t \text{div}((\nabla x(t))_j), | ||
which represents the diffusion-based message passing scheme (DMP) of the High-order Graph Diffusion Network (HiD-Net). | ||
This scheme leverages the information from two-hop neighbors, offering two main advantages: | ||
it captures the local environment around a node, enhancing the robustness of the model against abnormal features within one-hop neighbors; | ||
and it utilizes the monophily property of two-hop neighbors, which provides a stronger correlation with labels | ||
and thus enables better predictions even in the presence of heterophily within one-hop neighbors. | ||
Parameters | ||
---------- | ||
alpha: float | ||
beta: float | ||
gamma: float | ||
sigma1: float | ||
sigma2: float | ||
''' | ||
def __init__(self, | ||
alpha, | ||
beta, | ||
gamma, | ||
sigma1, | ||
sigma2): | ||
super().__init__() | ||
self.alpha = alpha | ||
self.beta = beta | ||
self.gamma = gamma | ||
self.sigma1 = sigma1 | ||
self.sigma2 = sigma2 | ||
self.reset_parameters() | ||
|
||
def reset_parameters(self): | ||
|
||
self._cached_edge_index = None | ||
self._cached_adj_t = None | ||
|
||
def forward(self, x, origin, edge_index, edge_weight, ei_no_loops, ew_no_loops, num_nodes=None): | ||
if num_nodes == None: | ||
num_nodes = maybe_num_nodes(edge_index, num_nodes) | ||
|
||
ew2 = tlx.reshape(ew_no_loops, (-1, 1)) | ||
|
||
g = cal_g_gradient(edge_index=ei_no_loops, x=x, edge_weight=ew2, sigma1=self.sigma1, sigma2=self.sigma2, dtype=None) | ||
|
||
Ax = self.propagate(x, edge_index, edge_weight=edge_weight, num_nodes=num_nodes) | ||
Gx = self.propagate(g, edge_index, edge_weight=edge_weight, num_nodes=num_nodes) | ||
|
||
x = self.alpha * origin + (1 - self.alpha - self.beta) * x \ | ||
+ self.beta * Ax \ | ||
+ self.beta * self.gamma * Gx | ||
|
||
return x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.