This repository contains the code used for the experiments in:
"Understanding Pooling in Graph Neural Networks"
D. Grattarola, D. Zambon, F. M. Bianchi, C. Alippi
https://arxiv.org/abs/2110.05292
The dependencies of the project are listed in requirements.txt. You can install them with:
pip install -r requirements.txt
The code to run our experiments is in the following folders:
autoencoder/
spectral_similarity/
graph_classification/
Each folder has a script called run_all.sh that will reproduce the results reported in the paper.
To generate the plots and tables from the paper, you can use the plots.py, plots_datasets.py, or tables.py scripts in each folder.
To run experiments for an individual pooling operator, you can use the run_[OPERATOR NAME].py scripts in each folder.
The pooling operators that we used for the experiments are in layers/ (trainable) and modules/ (non-trainable).
The GNN architectures used in the experiments are in models/.
The core of this repository is the SRCPool class, which implements a general interface to create SRC pooling layers with the Keras API.
Our implementations of MinCutPool, DiffPool, LaPool, Top-K, and SAGPool using the SRCPool class can be found in src/layers.
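SRC layers have the following structure: select() computes the supernode assignments S, reduce() uses S to compute the node features of the pooled graph, and connect() uses S to compute its adjacency matrix.
By extending this class, it is possible to create any pooling layer in the SRC framework.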
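The three SRC operations (select, reduce, connect) can be illustrated for a single dense graph with a framework-free NumPy sketch. It is not the repository's Keras SRCPool: the class name DenseSRCSketch, the weight matrix W, and the softmax selection are assumptions modeled on DiffPool/MinCutPool-style dense pooling, where reduce is S^T X and connect is S^T A S.

```python
import numpy as np


class DenseSRCSketch:
    """Hypothetical dense SRC pooling for one graph (no batch dimension)."""

    def __init__(self, n_features, k, seed=0):
        rng = np.random.default_rng(seed)
        # Hypothetical weights mapping F node features to K supernode logits
        self.W = rng.normal(size=(n_features, k))

    def select(self, X, A):
        # Soft assignments S of shape (N, K): row-wise softmax of X @ W
        logits = X @ self.W
        logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
        S = np.exp(logits)
        return S / S.sum(axis=1, keepdims=True)

    def reduce(self, X, S):
        # Each supernode's features are the assignment-weighted sum of node features
        return S.T @ X  # (K, F)

    def connect(self, A, S):
        # Edge weight between supernodes i and j: sum over u, v of S[u, i] A[u, v] S[v, j]
        return S.T @ A @ S  # (K, K)

    def pool(self, X, A):
        S = self.select(X, A)
        return self.reduce(X, S), self.connect(A, S), S


# Usage: a ring of 6 nodes with 3 features each, pooled to 2 supernodes
X = np.arange(18, dtype=float).reshape(6, 3)
A = np.zeros((6, 6))
for u in range(6):
    A[u, (u + 1) % 6] = A[(u + 1) % 6, u] = 1.0
X_pool, A_pool, S = DenseSRCSketch(n_features=3, k=2).pool(X, A)
print(X_pool.shape, A_pool.shape)  # (2, 3) (2, 2)
```

Because every row of S sums to 1, this kind of pooling preserves the total node-feature mass and the total edge weight of the input graph.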
Input
X: Tensor of shape ([batch], N, F) representing node features;
A: Tensor or SparseTensor of shape ([batch], N, N) representing the adjacency matrix;
I: (optional) Tensor of integers with shape (N, ) representing the batch index.
Output
X_pool: Tensor of shape ([batch], K, F) representing the node features of the output. K is the number of output nodes and depends on the specific pooling strategy;
A_pool: Tensor or SparseTensor of shape ([batch], K, K) representing the adjacency matrix of the output;
I_pool: (only if I was given as input) Tensor of integers with shape (K, ) representing the batch index of the output;
S_pool: (if return_sel=True) Tensor or SparseTensor representing the supernode assignments.
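To make the input and output shapes concrete, here is a hypothetical Top-K-style sketch in NumPy with a batch index I that assigns each node to a graph. It shows how K follows from the pooling strategy (here, ceil(ratio * n_g) nodes kept per graph) and how I_pool is obtained from I. The function names, the projection vector w, and the tanh gating (one common Top-K formulation) are assumptions, not the repository's Top-K layer.

```python
import numpy as np


def topk_select(X, I, w, ratio):
    """Hypothetical Top-K selection: keep the ceil(ratio * n_g) best nodes of each graph g."""
    # Score each node by projecting its features onto the (hypothetical) vector w
    scores = X @ w / np.linalg.norm(w)
    keep = []
    for g in np.unique(I):
        nodes = np.flatnonzero(I == g)
        k = int(np.ceil(ratio * len(nodes)))
        best = nodes[np.argsort(-scores[nodes])[:k]]
        keep.append(np.sort(best))  # preserve the original node order
    return np.concatenate(keep), scores


def topk_pool(X, A, I, w, ratio=0.5):
    idx, scores = topk_select(X, I, w, ratio)
    # Reduce: keep the selected rows, gated by tanh of their scores
    X_pool = X[idx] * np.tanh(scores[idx])[:, None]
    # Connect: keep the subgraph induced by the selected nodes
    A_pool = A[np.ix_(idx, idx)]
    # Reduce the batch index: every kept node keeps its graph id
    I_pool = I[idx]
    return X_pool, A_pool, I_pool


# Two graphs in one batch: nodes 0-3 belong to graph 0, nodes 4-5 to graph 1
I = np.array([0, 0, 0, 0, 1, 1])
X = np.array([[0.1, 0.0], [0.9, 0.0], [0.5, 0.0], [0.7, 0.0], [0.2, 0.0], [0.8, 0.0]])
A = np.zeros((6, 6))
for u, v in [(0, 1), (1, 2), (2, 3), (4, 5)]:
    A[u, v] = A[v, u] = 1.0
X_pool, A_pool, I_pool = topk_pool(X, A, I, w=np.array([1.0, 0.0]))
print(X_pool.shape, A_pool.shape, I_pool)  # (3, 2) (3, 3) [0 0 1]
```

Here N = 6 input nodes become K = 3 output nodes, and I_pool still tells which graph each output node belongs to.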
API
pool(X, A, I, **kwargs): pools the graph and returns the reduced node features and adjacency matrix. If the batch index I is not None, a reduced version of I will be returned as well. Any given kwargs will be passed as keyword arguments to select(), reduce(), and connect() if any matching key is found. The mandatory arguments of pool() (X, A, and I) must be computed in call() by calling self.get_inputs(inputs).
select(X, A, I, **kwargs): computes supernode assignments mapping the nodes of the input graph to the nodes of the output.
reduce(X, S, **kwargs): reduces the supernodes to form the nodes of the pooled graph.
connect(A, S, **kwargs): connects the reduced supernodes.
reduce_index(I, S, **kwargs): helper function to reduce the batch index (only called if I is given as input).
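The keyword-argument routing described for pool() can be illustrated with Python's inspect module: each extra kwarg is forwarded only to the methods that declare a parameter with that name, and unmatched keys are dropped. This is a hypothetical sketch; matching_kwargs and KwargsRoutingSketch are not part of the repository, the strings stand in for tensors, and the repository's own matching logic may differ.

```python
import inspect


def matching_kwargs(fn, kwargs):
    """Keep only the kwargs that fn declares as named (non-variadic) parameters."""
    params = inspect.signature(fn).parameters
    named = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    return {key: value for key, value in kwargs.items()
            if key in params and params[key].kind in named}


class KwargsRoutingSketch:
    """Hypothetical SRC-like class; strings stand in for tensors."""

    def __init__(self):
        self.seen = {}  # records the keyword arguments each method received

    def select(self, X, A, I, temperature=1.0):
        self.seen["select"] = {"temperature": temperature}
        return "S"

    def reduce(self, X, S, mask=None):
        self.seen["reduce"] = {"mask": mask}
        return "X_pool"

    def connect(self, A, S):
        self.seen["connect"] = {}
        return "A_pool"

    def reduce_index(self, I, S):
        return "I_pool"

    def pool(self, X, A, I, **kwargs):
        S = self.select(X, A, I, **matching_kwargs(self.select, kwargs))
        X_pool = self.reduce(X, S, **matching_kwargs(self.reduce, kwargs))
        A_pool = self.connect(A, S, **matching_kwargs(self.connect, kwargs))
        if I is None:
            return X_pool, A_pool
        # The batch index is reduced only when it was given
        I_pool = self.reduce_index(I, S, **matching_kwargs(self.reduce_index, kwargs))
        return X_pool, A_pool, I_pool


layer = KwargsRoutingSketch()
print(layer.pool("X", "A", None, temperature=0.5, mask=[1, 0], unused=3))
# ('X_pool', 'A_pool')
print(layer.seen["select"], layer.seen["reduce"])
# {'temperature': 0.5} {'mask': [1, 0]}
```

The key unused matches no method signature, so it is silently ignored rather than raising a TypeError.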
When overriding any function of the API, it is possible to access the true number of nodes of the input (N) as a Tensor in the instance variable self.N (this is populated by self.get_inputs() at the beginning of call()).
Arguments
return_sel: if True, the Tensor used to represent supernode assignments will be returned with X_pool, A_pool, and I_pool.
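The order of operations in call() (get_inputs() first, so that self.N is set, then pool(), with the assignments appended when return_sel is True) can be mirrored in a small NumPy sketch. Everything here is hypothetical: CallFlowSketch is not the Keras SRCPool, self.N is a plain int rather than a Tensor, and the degree-based selection rule is chosen only for illustration, not taken from any operator in the paper.

```python
import math

import numpy as np


class CallFlowSketch:
    """Hypothetical NumPy mirror of an SRC layer's call(); not the Keras SRCPool."""

    def __init__(self, ratio=0.5, return_sel=False):
        self.ratio = ratio
        self.return_sel = return_sel
        self.N = None

    def get_inputs(self, inputs):
        # Unpack [X, A] or [X, A, I] and record the true number of input nodes
        X, A, *rest = inputs
        I = rest[0] if rest else None
        self.N = X.shape[-2]
        return X, A, I

    def select(self, X, A, I):
        # self.N is already set because get_inputs() runs first in call()
        k = math.ceil(self.ratio * self.N)
        degree = A.sum(axis=1)
        idx = np.sort(np.argsort(-degree, kind="stable")[:k])
        S = np.zeros((self.N, k))
        S[idx, np.arange(k)] = 1.0  # one-hot: node idx[j] becomes supernode j
        return S

    def reduce(self, X, S):
        return S.T @ X

    def connect(self, A, S):
        return S.T @ A @ S

    def reduce_index(self, I, S):
        return I[S.argmax(axis=0)]  # graph id of the node behind each supernode

    def pool(self, X, A, I):
        S = self.select(X, A, I)
        outputs = [self.reduce(X, S), self.connect(A, S)]
        if I is not None:
            outputs.append(self.reduce_index(I, S))
        if self.return_sel:
            outputs.append(S)  # supernode assignments are returned last
        return outputs

    def call(self, inputs):
        X, A, I = self.get_inputs(inputs)
        return self.pool(X, A, I)


# Usage: a path graph 0-1-2-3; the two inner nodes have the highest degree
A = np.array([[0, 1, 0, 0], [1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 1, 0]], dtype=float)
X = np.arange(8, dtype=float).reshape(4, 2)
I = np.zeros(4, dtype=int)
layer = CallFlowSketch(ratio=0.5, return_sel=True)
X_pool, A_pool, I_pool, S = layer.call([X, A, I])
print(layer.N, X_pool.shape, A_pool.tolist(), S.shape)
# 4 (2, 2) [[0.0, 1.0], [1.0, 0.0]] (4, 2)
```

With return_sel=False and no batch index, the same call returns only X_pool and A_pool.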