Skip to content

Commit

Permalink
[Model] GCIL (#226)
Browse files Browse the repository at this point in the history
* 更新gcil模型与训练

* 更新

* 更新

* 删除多余文件

* fix bugs

---------

Co-authored-by: jxy <865526875@qq.com>
  • Loading branch information
cjx2004 and xy-Ji authored Jan 19, 2025
1 parent 1ab7720 commit f21a720
Show file tree
Hide file tree
Showing 23 changed files with 707 additions and 1 deletion.
35 changes: 35 additions & 0 deletions examples/gcil/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Graph Contrastive Invariant Learning from the Causal Perspective (GCIL)

- Paper link: [https://arxiv.org/pdf/2401.12564v2](https://arxiv.org/pdf/2401.12564v2)
- Author's code repo: [https://github.com/BUPT-GAMMA/GCIL](https://github.com/BUPT-GAMMA/GCIL). Note that the original code is
implemented with Tensorflow for the paper.

# Dataset Statics

| Dataset | # Nodes | # Edges | # Classes |
| ------- | ------- | ------- | --------- |
| Cora | 2,708 | 10,556 | 7 |
| Pubmed | 19,717 | 88,651 | 3 |

Refer to [Planetoid](https://gammagl.readthedocs.io/en/latest/api/gammagl.datasets.html#gammagl.datasets.Planetoid).

Results
-------

```bash
# available dataset: "cora", "pubmed"
TL_BACKEND="torch" python gcil_trainer.py cora
TL_BACKEND="torch" python gcil_trainer.py pubmed
```

Ma-F1:
| Dataset | Paper | Our(th) |
| ------- | -------- | ---------- |
| cora | 83.8±0.5 | 45.19±0.22 |
| pubmed | 81.5±0.5 | 46.30±0.02 |

Mi-F1
| Dataset | Paper | Our(th) |
| ------- | -------- | ---------- |
| cora | 84.4±0.7 | 49.71±0.22 |
| pubmed | 81.6±0.7 | 53.77±0.01 |
50 changes: 50 additions & 0 deletions examples/gcil/aug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import torch as th
import numpy as np
from gammagl.data import Graph


def random_aug(graph, attr, diag_attr, x, feat_drop_rate, edge_mask_rate, device):
n_node = graph.num_nodes if hasattr(graph, 'num_nodes') else graph.x.shape[0]

edge_mask = mask_edge(graph, edge_mask_rate)

feat = drop_feature(x, feat_drop_rate)

edge_index = graph.edge_index[:, edge_mask].long()

edge_weight = attr[edge_mask] if attr is not None else None

if isinstance(attr, np.ndarray):
attr_cpu = attr.cpu().numpy() if attr.is_cuda else attr.numpy()
diag_attr_cpu = diag_attr.cpu().numpy() if diag_attr.is_cuda else diag_attr.numpy()
attr = np.concatenate([attr_cpu[edge_mask], diag_attr_cpu], axis=0)

new_graph = Graph(x=feat, edge_index=edge_index)

new_graph.x = new_graph.x.to(device)
new_graph.edge_index = new_graph.edge_index.to(device)

if edge_weight is not None:
edge_weight = th.tensor(edge_weight, dtype=th.float32).to(device) # 确保 edge_weight 是 FloatTensor

return new_graph, edge_weight, feat



def drop_feature(x, drop_prob):
drop_mask = th.empty(
(x.size(1),),
dtype=th.float32,
device=x.device).uniform_(0, 1) < drop_prob
x = x.clone()
x[:, drop_mask] = 0

return x


def mask_edge(graph, edge_mask_rate):
E = graph.edge_index.shape[1]

mask = np.random.rand(E) > edge_mask_rate

return mask
Binary file added examples/gcil/dataset/cora/0.01_1_0.npz
Binary file not shown.
Binary file added examples/gcil/dataset/cora/0.01_1_1.npz
Binary file not shown.
Binary file added examples/gcil/dataset/cora/0.01_1_2.npz
Binary file not shown.
Binary file added examples/gcil/dataset/cora/0.01_1_3.npz
Binary file not shown.
Binary file added examples/gcil/dataset/cora/0.01_1_4.npz
Binary file not shown.
Binary file added examples/gcil/dataset/cora/0.01_1_5.npz
Binary file not shown.
Binary file added examples/gcil/dataset/cora/0.01_1_6.npz
Binary file not shown.
Binary file added examples/gcil/dataset/pubmed/0.01_1_0.npz
Binary file not shown.
Binary file added examples/gcil/dataset/pubmed/0.01_1_1.npz
Binary file not shown.
Binary file added examples/gcil/dataset/pubmed/0.01_1_2.npz
Binary file not shown.
Binary file added examples/gcil/dataset/pubmed/0.01_1_3.npz
Binary file not shown.
Binary file added examples/gcil/dataset/pubmed/0.01_1_4.npz
Binary file not shown.
Binary file added examples/gcil/dataset/pubmed/0.01_1_5.npz
Binary file not shown.
Binary file added examples/gcil/dataset/pubmed/0.01_1_6.npz
Binary file not shown.
Binary file added examples/gcil/dataset/pubmed/0.01_1_7.npz
Binary file not shown.
Binary file added examples/gcil/dataset/pubmed/0.01_1_8.npz
Binary file not shown.
Binary file added examples/gcil/dataset/pubmed/0.01_1_9.npz
Binary file not shown.
Loading

0 comments on commit f21a720

Please sign in to comment.