This project explores link prediction using Graph Neural Networks (GNNs), focusing on the Cora citation dataset. The task involves predicting missing or potential future links (edges) in the graph. The implementation compares three state-of-the-art GNN architectures:
- Graph Convolutional Networks (GCN)
- GraphSAGE
- Graph Attention Networks (GAT)
This repository is inspired by the research paper "Predicting the Future of AI with AI: High-quality Link Prediction in an Exponentially Growing Knowledge Network" and adapts its techniques to a static graph scenario.
- Dataset: The Cora citation dataset contains:
- Nodes: 2,708 scientific papers.
- Edges: 5,429 citation links.
- Node Features: Bag-of-words representation (1,433 features per node).
- Classes: Seven categories of research topics.
- Task: Predict whether an edge exists between two nodes (link prediction).
- Key Metrics:
- AUC (Area Under the ROC Curve): Measures how well the model ranks true edges above sampled non-edges across all thresholds.
- AP (Average Precision): Summarizes the precision-recall curve as a single score.
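As a dependency-free sketch, both metrics can be computed directly from the model's scores on positive (true-edge) and negative (non-edge) pairs. The function names here are illustrative, not taken from the notebooks:

```python
def auc(pos_scores, neg_scores):
    # AUC = fraction of (positive, negative) pairs where the positive
    # edge is scored higher; ties count as half.
    total = correct = 0.0
    for p in pos_scores:
        for n in neg_scores:
            total += 1
            if p > n:
                correct += 1
            elif p == n:
                correct += 0.5
    return correct / total

def average_precision(pos_scores, neg_scores):
    # Rank all scored pairs, then average precision at each true-edge hit.
    ranked = sorted([(s, 1) for s in pos_scores] + [(s, 0) for s in neg_scores],
                    key=lambda t: -t[0])
    hits, precisions = 0, []
    for rank, (_, label) in enumerate(ranked, start=1):
        if label == 1:
            hits += 1
            precisions.append(hits / rank)
    return sum(precisions) / len(pos_scores)
```

In practice libraries such as scikit-learn provide these as `roc_auc_score` and `average_precision_score`; the sketch above just makes the definitions concrete.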
GNNs extend traditional neural networks to graph-structured data by learning node embeddings that capture both local and global graph structures. They work by propagating and aggregating information between neighboring nodes.
The link prediction task involves estimating the probability P(u, v) that an edge exists between two nodes u and v. The prediction is based on the embeddings learned by the GNN.
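A common decoder for this probability is the sigmoid of the dot product of the two node embeddings; the sketch below assumes that decoder (the notebooks may use a different one):

```python
import math

def link_probability(z_u, z_v):
    # P(u, v) = sigmoid(z_u . z_v): higher embedding similarity
    # translates to higher predicted edge probability.
    score = sum(a * b for a, b in zip(z_u, z_v))
    return 1.0 / (1.0 + math.exp(-score))
```

Orthogonal embeddings give a score of 0 and hence a probability of exactly 0.5, i.e. maximal uncertainty about the edge.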
GCNs perform spectral convolutions to capture local neighborhood information.
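A minimal, dependency-free sketch of one GCN propagation step with symmetric normalization, H' = D^{-1/2}(A + I)D^{-1/2}H; the learned weight matrix and nonlinearity are omitted, and the function name is illustrative:

```python
import math

def gcn_propagate(features, adj):
    # One GCN propagation step: each node averages its neighborhood
    # (self-loop included), weighted by 1/sqrt(deg_i * deg_j).
    n = len(features)
    neigh = {i: set(adj.get(i, [])) | {i} for i in range(n)}  # add self-loops
    deg = {i: len(neigh[i]) for i in range(n)}
    out = []
    for i in range(n):
        acc = [0.0] * len(features[0])
        for j in neigh[i]:
            w = 1.0 / math.sqrt(deg[i] * deg[j])
            for k, x in enumerate(features[j]):
                acc[k] += w * x
        out.append(acc)
    return out
```

For two connected nodes with features [1.0] and [3.0], both end up at [2.0]: propagation smooths features over the neighborhood.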
GraphSAGE introduces inductive learning by sampling fixed-size neighborhoods.
Common aggregators include:
- Mean pooling
- Max pooling
- LSTM pooling
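The mean aggregator, for instance, averages neighbor features and concatenates the result with the node's own representation. This is a sketch of that step only; in the full model a learned weight matrix and nonlinearity follow, and the function name is illustrative:

```python
def sage_mean_aggregate(h_self, neighbor_feats):
    # GraphSAGE mean aggregator: average the sampled neighbors'
    # features, then concatenate with the node's own features.
    dim = len(h_self)
    mean = [sum(f[k] for f in neighbor_feats) / len(neighbor_feats)
            for k in range(dim)]
    return h_self + mean  # concatenation doubles the dimension
```

Because aggregation depends only on sampled neighborhoods, not on a fixed global graph, the same operator applies to nodes unseen during training, which is what makes GraphSAGE inductive.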
GAT uses attention mechanisms to dynamically assign importance to neighbors.
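In GAT, the raw score for each neighbor is a LeakyReLU of a learned vector applied to the concatenated features, and the scores are softmax-normalized over the neighborhood. A dependency-free sketch (the attention vector `a` would be learned; here it is passed in, and the function names are illustrative):

```python
import math

def leaky_relu(x, slope=0.2):
    return x if x > 0 else slope * x

def gat_attention(h_i, neighbors, a):
    # Raw scores e_ij = LeakyReLU(a . [h_i || h_j]),
    # then softmax over the neighborhood to get alpha_ij.
    scores = [leaky_relu(sum(w * x for w, x in zip(a, h_i + h_j)))
              for h_j in neighbors]
    m = max(scores)  # subtract max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]
```

Identical neighbors receive equal weights; the learned vector `a` lets the model deviate from uniform averaging when some neighbors are more informative.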
- Splitting Edges:
- Training: 85%
- Validation: 10%
- Testing: 5%
- Negative Sampling:
- Negative edges are sampled to balance the dataset for binary classification.
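A simple rejection-sampling sketch of this step, assuming an undirected graph sparse enough that non-edges are plentiful (the function name and seed handling are illustrative; libraries such as PyTorch Geometric provide an equivalent `negative_sampling` utility):

```python
import random

def sample_negative_edges(num_nodes, pos_edges, k, seed=0):
    # Draw k node pairs that are neither self-loops nor existing
    # (undirected) edges, to serve as negative examples.
    rng = random.Random(seed)
    existing = set(pos_edges) | {(v, u) for u, v in pos_edges}
    negatives = set()
    while len(negatives) < k:
        u, v = rng.randrange(num_nodes), rng.randrange(num_nodes)
        if u != v and (u, v) not in existing:
            negatives.add((u, v))
    return sorted(negatives)
```

Typically k is set equal to the number of positive edges in each split, so the binary classifier sees a balanced label distribution.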
| Model | Test AUC | Test AP |
|---|---|---|
| GCN | 0.9268 | 0.9324 |
| GraphSAGE | 0.7394 | 0.7423 |
| GAT | 0.5803 | 0.5754 |
- GCN emerged as the best-performing model in the second notebook, achieving the highest Test AUC (0.9268) and Test AP (0.9324).
This project demonstrates the effectiveness of GNNs in link prediction tasks. The results also show that more complex architectures such as GAT do not automatically outperform simpler spectral convolutions on this dataset. This implementation shows:
- High Accuracy: Achieved a Test AUC of 0.9268 with GCN.
- Scalability: Models like GraphSAGE support inductive learning and can generalize to nodes unseen during training.
- Applications: These techniques are broadly applicable to social networks, recommendation systems, and biological networks.
- Extend the models to handle dynamic graphs.
- Incorporate edge attributes for richer representations.
- Apply these techniques to larger, real-world datasets.
- Clone this repository:
  ```bash
  git clone <repository-url>
  ```
- Install dependencies:
  ```bash
  pip install -r requirements.txt
  ```
- Run the notebook:
  ```bash
  jupyter notebook linkpredict2.ipynb
  ```
- Kipf, T. N., & Welling, M. (2017). Semi-Supervised Classification with Graph Convolutional Networks.
- Hamilton, W., Ying, Z., & Leskovec, J. (2017). Inductive Representation Learning on Large Graphs.
- Veličković, P., et al. (2018). Graph Attention Networks.
- Predicting the Future of AI with AI: High-quality Link Prediction in an Exponentially Growing Knowledge Network.