This repository implements the Fisher information embedding (FIE) described in the following paper:
Dexiong Chen *, Paolo Pellizzoni *, and Karsten Borgwardt. Fisher Information Embedding for Node and Graph Learning. ICML 2023.
* Equal contribution
TL;DR: a class of node embeddings with an information geometry interpretation, available with both unsupervised and supervised algorithms for learning the embeddings.
Please use the following to cite our work:
@inproceedings{chen23fie,
author = {Dexiong Chen and Paolo Pellizzoni and Karsten Borgwardt},
title = {Fisher Information Embedding for Node and Graph Learning},
year = {2023},
booktitle = {International Conference on Machine Learning~(ICML)},
series = {Proceedings of Machine Learning Research}
}
In this work, we propose a novel attention-based node embedding framework for graphs. Our framework builds upon a hierarchical kernel for multisets of subgraphs around nodes (e.g., neighborhoods), and each kernel leverages the geometry of a smooth statistical manifold to compare pairs of multisets by “projecting” the multisets onto the manifold. By explicitly computing node embeddings with a manifold of Gaussian mixtures, our method leads to a new attention mechanism for neighborhood aggregation.
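To give a rough feel for how a Gaussian mixture induces attention, here is a minimal NumPy sketch. It is not the implementation in this repository. It assumes isotropic components with a shared, fixed variance `sigma**2` and uniform mixture weights, and the function name `mixture_attention_embedding` is hypothetical. For each neighbor, the posterior responsibility of each component acts as an attention weight, and the output is the gradient of the average log-likelihood of the neighborhood with respect to each component mean (a Fisher-score-style statistic).

```python
import numpy as np


def mixture_attention_embedding(neighbors, means, sigma=1.0):
    """Hypothetical sketch: embed a multiset of vectors (e.g., a node's
    neighborhood features) with respect to a Gaussian mixture with
    covariance sigma^2 * I and uniform mixture weights (both assumptions).

    Returns an array of shape (num_mixtures, dim) whose row j is
    (1/n) * sum_i attn[i, j] * (x_i - mu_j) / sigma^2, i.e. the gradient of
    the mean log-likelihood of the neighbors with respect to mu_j.
    """
    neighbors = np.asarray(neighbors, dtype=float)  # shape (n, d)
    means = np.asarray(means, dtype=float)  # shape (k, d)

    # Residuals between every neighbor and every component mean: (n, k, d)
    residuals = neighbors[:, None, :] - means[None, :, :]
    sq_dist = (residuals ** 2).sum(axis=-1)  # (n, k)

    # Posterior responsibilities = softmax over components; these play the
    # role of attention weights (shifted by the row max for stability)
    logits = -sq_dist / (2.0 * sigma ** 2)
    logits -= logits.max(axis=1, keepdims=True)
    attn = np.exp(logits)
    attn /= attn.sum(axis=1, keepdims=True)

    # Attention-weighted residuals, averaged over the neighbors
    weighted = (attn[:, :, None] * residuals).sum(axis=0)
    return weighted / (len(neighbors) * sigma ** 2)
```

With a single component every attention weight is 1, so the output reduces to the mean residual of the neighbors: `mixture_attention_embedding([[1, 2], [3, 4]], [[0, 0]])` equals `[[2.0, 3.0]]`. With several components, each neighbor contributes mostly to the components it is closest to.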
A minimal example of computing node embeddings on Cora:
```python
from torch_geometric import datasets
from torch_geometric.loader import DataLoader

# Import path assumed from the package name; adjust if FIENet lives elsewhere
from fisher_information_embedding.models import FIENet

# Construct data loader
dataset = datasets.Planetoid('./datasets/citation', name='Cora', split='public')
data_loader = DataLoader(dataset, batch_size=1, shuffle=False)
input_size = dataset.num_node_features

# Build FIE model
model = FIENet(
    input_size,
    num_layers=2,
    hidden_size=16,
    num_mixtures=8,
    pooling=None,
    concat=True
)

# Train model parameters using k-means
model.unsup_train(data_loader)

# Compute node embeddings
X = model.predict(data_loader)
```
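In the example, the model parameters are trained with k-means (`model.unsup_train`). To illustrate that step only, here is a plain Lloyd's k-means in NumPy. The function `kmeans` is hypothetical and not part of this package's API, and using the resulting centers as mixture means is an assumption about how such an estimate would feed into the model.

```python
import numpy as np


def _nearest_center(points, centers):
    """Index of the closest center (squared Euclidean distance) per point."""
    sq_dist = ((points[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
    return sq_dist.argmin(axis=1)


def kmeans(points, num_clusters, num_iters=50, seed=0):
    """Hypothetical Lloyd's k-means. Returns (centers, labels).

    Requires num_clusters <= len(points).
    """
    rng = np.random.default_rng(seed)
    points = np.asarray(points, dtype=float)

    # Initialize the centers with distinct, randomly chosen points
    init_idx = rng.choice(len(points), size=num_clusters, replace=False)
    centers = points[init_idx].copy()

    for _ in range(num_iters):
        # Assignment step: each point goes to its nearest center
        labels = _nearest_center(points, centers)
        # Update step: move each center to the mean of its points
        # (an empty cluster keeps its previous center)
        new_centers = centers.copy()
        for j in range(num_clusters):
            members = points[labels == j]
            if len(members) > 0:
                new_centers[j] = members.mean(axis=0)
        if np.allclose(new_centers, centers):
            break
        centers = new_centers

    # Recompute labels so they always match the returned centers
    return centers, _nearest_center(points, centers)
```

For instance, on the points `[[0, 0], [0, 1], [10, 10], [10, 11]]` with two clusters, the centers converge to `[0, 0.5]` and `[10, 10.5]` (in some order). Lloyd's algorithm only finds a local optimum, so in practice one would typically run several initializations or use k-means++ seeding.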
The dependencies are managed with miniconda. Run the following to create and activate the environment:
```bash
# For CPU only
conda env create -f env.yml
# Or if you have a GPU
conda env create -f env_cuda.yml
# Then activate the environment
conda activate fie
```
Then, install our `fisher_information_embedding` package:

```bash
pip install -e .
```
Please see Tables 3 and 4 in our paper for the search grid of each hyperparameter. Note that we performed only minimal hyperparameter tuning in our paper.
Citation networks (example on Cora):

- Unsupervised node embedding mode with a logistic classifier:
  ```bash
  python train_citation.py --dataset Cora --hidden-size 512 --num-mixtures 8 --num-layers 4
  ```
- Supervised node embedding mode:
  ```bash
  python train_citation_sup.py --dataset Cora --hidden-size 64 --num-mixtures 8 --num-layers 4
  ```

OGB node property prediction (example on ogbn-arxiv):

- Unsupervised node embedding mode with FLAML:
  ```bash
  python train_ogb_node.py --save-memory --dataset ogbn-arxiv --hidden-size 256 --num-mixtures 8 --num-layers 5
  ```
- Supervised node embedding mode:
  ```bash
  python train_ogb_node_sup_ns.py --dataset ogbn-arxiv --hidden-size 256 --num-mixtures 4 --num-layers 3
  ```
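To sweep a search grid with the training scripts, one option is to generate one command per hyperparameter combination. The sketch below only builds command strings from the `--dataset`, `--hidden-size`, `--num-mixtures` and `--num-layers` flags used in the commands above. The helper `grid_commands` is hypothetical, and the grid values in the usage note are placeholders, not the grids from Tables 3 and 4 of the paper.

```python
import itertools
import shlex


def grid_commands(script, dataset, hidden_sizes, num_mixtures, num_layers):
    """Hypothetical helper: one shell command per hyperparameter combination."""
    commands = []
    for hidden, mixtures, layers in itertools.product(
        hidden_sizes, num_mixtures, num_layers
    ):
        argv = [
            "python", script,
            "--dataset", dataset,
            "--hidden-size", str(hidden),
            "--num-mixtures", str(mixtures),
            "--num-layers", str(layers),
        ]
        # shlex.join quotes any argument that would need it in a shell
        commands.append(shlex.join(argv))
    return commands
```

For example, `grid_commands("train_citation.py", "Cora", [64, 512], [4, 8], [2, 4])` yields 8 commands, the first being `python train_citation.py --dataset Cora --hidden-size 64 --num-mixtures 4 --num-layers 2`. Each string can then be executed, e.g. with `subprocess.run(shlex.split(cmd), check=True)`.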