# Fisher information embedding for node and graph learning

This repository implements the Fisher information embedding (FIE) described in the following paper:

> Dexiong Chen\*, Paolo Pellizzoni\*, and Karsten Borgwardt. Fisher Information Embedding for Node and Graph Learning. ICML 2023.
> (\*Equal contribution)

**TL;DR**: a class of node embeddings with an information-geometric interpretation, with both unsupervised and supervised algorithms for learning the embeddings.

## Citation

Please use the following to cite our work:

```bibtex
@inproceedings{chen23fie,
    author = {Dexiong Chen and Paolo Pellizzoni and Karsten Borgwardt},
    title = {Fisher Information Embedding for Node and Graph Learning},
    year = {2023},
    booktitle = {International Conference on Machine Learning~(ICML)},
    series = {Proceedings of Machine Learning Research}
}
```

## A short description of FIE

In this work, we propose a novel attention-based node embedding framework for graphs. Our framework builds on a hierarchical kernel for multisets of subgraphs around nodes (e.g., neighborhoods); at each level, the kernel leverages the geometry of a smooth statistical manifold to compare pairs of multisets by “projecting” them onto the manifold. By explicitly computing node embeddings with a manifold of Gaussian mixtures, our method yields a new attention mechanism for neighborhood aggregation.
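
Schematically, and in the notation of the figure below, the embedding of a node $v$ is the tangent vector at an anchor distribution $p_{\theta_0}$ obtained by fitting the multiset of its neighborhood features with a few EM steps started at $\theta_0$ (here $\Psi$ and $\mathrm{EM}(\cdot\,;\cdot)$ are shorthand introduced for this README rather than the paper's exact notation):

$$
\Psi(v) \;=\; R_{\theta_0}^{-1}\big(\hat\theta_v\big) \in T_{\theta_0}\mathcal{M},
\qquad
\hat\theta_v \;=\; \mathrm{EM}\big(h(\mathcal{S}_G(v));\, \theta_0\big).
$$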

*Figure: An illustration of the Fisher information embedding for nodes. (a) Multisets $h(\mathcal{S}_G(\cdot))$ of node features are obtained from the neighborhoods of each node. (b) The multisets are mapped to parametric distributions, e.g. $p_\theta$ and $p_{\theta'}$, via maximum likelihood estimation. (c) The node embeddings are obtained by estimating the parameters of each distribution with the EM algorithm, using an anchor distribution $p_{\theta_0}$ as the starting point. The last panel shows the manifold of parametric distributions $\mathcal{M}$ and its tangent space $T_{\theta_0}\mathcal{M}$ at the anchor point $\theta_0$. The points $p_{\theta}$ and $p_{\theta'}$ are probability distributions on $\mathcal{M}$, and the gray dashed line between them is their geodesic distance. The red dashed lines represent the retraction mapping $R_{\theta_0}^{-1}$.*
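
To make the attention interpretation concrete, below is a minimal NumPy sketch of a single EM step from the anchor mixture. The function name, the isotropic-Gaussian components, and the fixed bandwidth `sigma` are simplifications for illustration, not the repository's implementation:

```python
import numpy as np

def fie_node_embedding(H, anchors, sigma=1.0):
    """One EM step from the anchor mixture, mapped to the tangent space.

    H: (n, d) multiset of neighbor features for one node.
    anchors: (k, d) component means of the anchor mixture p_{theta_0}.
    """
    # E-step: responsibilities of each component for each neighbor, i.e. a
    # softmax over negative squared distances -- the attention weights.
    d2 = ((H[:, None, :] - anchors[None, :, :]) ** 2).sum(-1)   # (n, k)
    logits = -d2 / (2.0 * sigma ** 2)
    resp = np.exp(logits - logits.max(axis=1, keepdims=True))
    resp /= resp.sum(axis=1, keepdims=True)
    # M-step: component means become responsibility-weighted neighbor averages.
    means = (resp[:, :, None] * H[:, None, :]).sum(0) / (resp.sum(0)[:, None] + 1e-8)
    # Embed as the displacement from the anchor (a tangent vector at theta_0).
    return (means - anchors).ravel()
```

Stacking these vectors over all nodes yields a node embedding matrix analogous to the one computed in the quickstart below.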

## Quickstart

```python
from torch_geometric import datasets
from torch_geometric.loader import DataLoader

# FIENet comes from this repository's package (installed below with
# `pip install -e .`); the exact module path may differ from this assumption.
from fisher_information_embedding import FIENet

# Construct a data loader for the Cora citation network
dataset = datasets.Planetoid('./datasets/citation', name='Cora', split='public')
data_loader = DataLoader(dataset, batch_size=1, shuffle=False)
input_size = dataset.num_node_features

# Build the FIE model
model = FIENet(
    input_size,
    num_layers=2,
    hidden_size=16,
    num_mixtures=8,
    pooling=None,
    concat=True
)

# Train model parameters using k-means (unsupervised)
model.unsup_train(data_loader)

# Compute node embeddings
X = model.predict(data_loader)
```
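
The embedding matrix `X` can then be fed to any downstream classifier. As a minimal sketch, assuming `model.predict` returns a NumPy array aligned with Cora's node ordering (scikit-learn is used here purely for illustration):

```python
from sklearn.linear_model import LogisticRegression

# Train and evaluate a linear classifier on top of the frozen FIE node
# embeddings, using Cora's public split masks.
data = dataset[0]
train_mask = data.train_mask.numpy()
test_mask = data.test_mask.numpy()

clf = LogisticRegression(max_iter=1000)
clf.fit(X[train_mask], data.y.numpy()[train_mask])
print('Test accuracy:', clf.score(X[test_mask], data.y.numpy()[test_mask]))
```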

## Installation

The dependencies are managed by miniconda. Run the following to install them:

```bash
# For CPU only
conda env create -f env.yml
# Or if you have a GPU
conda env create -f env_cuda.yml
# Then activate the environment
conda activate fie
```

Then, install our `fisher_information_embedding` package:

```bash
pip install -e .
```

## Training models

Please see Tables 3 and 4 in our paper for the search grids used for each hyperparameter. Note that we used only very minimal hyperparameter tuning in our paper.

### Training models on semi-supervised learning tasks using citation networks

- Unsupervised node embedding mode with a logistic regression classifier:

  ```bash
  python train_citation.py --dataset Cora --hidden-size 512 --num-mixtures 8 --num-layers 4
  ```

- Supervised node embedding mode:

  ```bash
  python train_citation_sup.py --dataset Cora --hidden-size 64 --num-mixtures 8 --num-layers 4
  ```

### Training models on supervised learning tasks using large OGB datasets

- Unsupervised node embedding mode with FLAML:

  ```bash
  python train_ogb_node.py --save-memory --dataset ogbn-arxiv --hidden-size 256 --num-mixtures 8 --num-layers 5
  ```

- Supervised node embedding mode:

  ```bash
  python train_ogb_node_sup_ns.py --dataset ogbn-arxiv --hidden-size 256 --num-mixtures 4 --num-layers 3
  ```