-
Notifications
You must be signed in to change notification settings - Fork 3
/
AsymCheegerCutPool.py
221 lines (186 loc) · 8.11 KB
/
AsymCheegerCutPool.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense
import tensorflow.keras.backend as K
from spektral.layers import ops
from spektral.layers.pooling.src import SRCPool
class AsymCheegerCutPool(SRCPool):
    r"""
    An Asymmetric Cheeger Cut Pooling layer from the paper

    > [Clustering with Total Variation Graph Neural Networks](https://arxiv.org/abs/2211.06218)
    > Jonas Berg Hansen and Filippo Maria Bianchi

    **Mode**: single, batch

    This layer learns a soft clustering of the input graph as follows:
    $$
    \begin{align}
        \S &= \textrm{MLP}(\X); \\
        \X' &= \S^\top \X; \\
        \A' &= \S^\top \A \S; \\
    \end{align}
    $$
    where \(\textrm{MLP}\) is a multi-layer perceptron with softmax output.

    The layer includes two auxiliary loss terms/components.

    A graph total variation loss given by
    $$
        L_\text{GTV} = \frac{1}{2E} \sum_{k=1}^K \sum_{i=1}^N \sum_{j=i}^N a_{i,j} |s_{i,k} - s_{j,k}|,
    $$
    where \(E\) is the number of edges/links, \(K\) is the number of clusters or
    output nodes, and \(N\) is the number of nodes. For a graph without edges
    this loss is defined as 0.

    An asymmetrical norm term given by
    $$
        L_\text{AN} = \frac{N(K - 1) - \sum_{k=1}^K ||\s_{:,k} - \textrm{quant}_\rho (\s_{:,k})||_{1, \rho}}{N(K-1)},
    $$
    where \(\textrm{quant}_\rho(\s_{:,k})\) is the
    \((\lfloor N / K \rfloor + 1)\)-th largest entry of column \(k\), and
    \(||\cdot||_{1,\rho}\) is an asymmetric \(\ell_1\) norm that weights positive
    deviations by \(\rho = K - 1\) and negative deviations by 1.

    The layer can be used without a supervised loss to compute node clustering by
    minimizing the two auxiliary losses.

    **Input**

    - Node features of shape `([batch], n_nodes_in, n_node_features)`;
    - Adjacency matrix of shape `([batch], n_nodes_in, n_nodes_in)`;

    **Output**

    - Reduced node features of shape `([batch], n_nodes_out, n_node_features)`;
    - Reduced adjacency matrix of shape `([batch], n_nodes_out, n_nodes_out)`;
    - If `return_selection=True`, the selection matrix of shape
    `([batch], n_nodes_in, n_nodes_out)`.

    **Arguments**

    - `k`: number of output nodes (clusters); must be at least 2;
    - `mlp_hidden`: list of integers, number of hidden units for each hidden layer in
    the MLP used to compute cluster assignments (if `None`, the MLP has only one output
    layer);
    - `mlp_activation`: activation for the hidden MLP layers (the output layer
    always uses softmax);
    - `return_selection`: boolean, whether to return the selection matrix;
    - `use_bias`: use bias in the MLP layers;
    - `totvar_coeff`: coefficient for graph total variation loss component;
    - `balance_coeff`: coefficient for asymmetric norm loss component;
    - `kernel_initializer`: initializer for the weights of the MLP;
    - `bias_initializer`: initializer for the bias of the MLP;
    - `kernel_regularizer`: regularization applied to the weights of the MLP;
    - `bias_regularizer`: regularization applied to the bias of the MLP;
    - `kernel_constraint`: constraint applied to the weights of the MLP;
    - `bias_constraint`: constraint applied to the bias of the MLP;
    """

    def __init__(
        self,
        k,
        mlp_hidden=None,
        mlp_activation="relu",
        return_selection=False,
        use_bias=True,
        totvar_coeff=1.0,
        balance_coeff=1.0,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        **kwargs
    ):
        # The asymmetric-norm loss normalises by (k - 1) and takes the
        # (floor(N / k) + 1)-th largest assignment per cluster; with k == 1
        # that rank exceeds N and top_k fails, so reject it up front.
        if k < 2:
            raise ValueError(f"AsymCheegerCutPool requires k >= 2, got k={k}.")
        super().__init__(
            k=k,
            mlp_hidden=mlp_hidden,
            mlp_activation=mlp_activation,
            return_selection=return_selection,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            kernel_constraint=kernel_constraint,
            bias_constraint=bias_constraint,
            **kwargs
        )
        self.k = k
        self.mlp_hidden = mlp_hidden if mlp_hidden else []
        self.mlp_activation = mlp_activation
        # Stored explicitly so build() can forward it to the Dense layers.
        self.use_bias = use_bias
        self.totvar_coeff = totvar_coeff
        self.balance_coeff = balance_coeff

    def build(self, input_shape):
        """Create the MLP that maps node features to soft cluster assignments."""
        dense_kwargs = dict(
            use_bias=self.use_bias,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
            kernel_regularizer=self.kernel_regularizer,
            bias_regularizer=self.bias_regularizer,
            kernel_constraint=self.kernel_constraint,
            bias_constraint=self.bias_constraint,
        )
        hidden_layers = [
            Dense(units, self.mlp_activation, **dense_kwargs)
            for units in self.mlp_hidden
        ]
        # Softmax output: each node's row is a distribution over the k clusters.
        output_layer = Dense(self.k, "softmax", **dense_kwargs)
        self.mlp = Sequential(hidden_layers + [output_layer])
        super().build(input_shape)

    def call(self, inputs, mask=None):
        """Unpack the inputs and run the select/reduce/connect pooling pipeline."""
        x, a, i = self.get_inputs(inputs)
        return self.pool(x, a, i, mask=mask)

    def select(self, x, a, i, mask=None):
        """Compute the assignment matrix S and register both auxiliary losses.

        Returns:
            The soft assignment matrix ``s`` produced by the MLP (optionally masked).
        """
        s = self.mlp(x)
        if mask is not None:
            # NOTE(review): assumes mask[0] masks the node axis of x — confirm
            # against SRCPool's masking convention.
            s *= mask[0]

        # In batch mode each loss is computed per graph and then averaged.
        batch_mode = K.ndim(a) == 3

        # Total variation loss
        cut_loss = self.totvar_loss(a, s)
        if batch_mode:
            cut_loss = K.mean(cut_loss)
        self.add_loss(self.totvar_coeff * cut_loss)

        # Asymmetric l1-norm loss
        bal_loss = self.balance_loss(s)
        if batch_mode:
            bal_loss = K.mean(bal_loss)
        self.add_loss(self.balance_coeff * bal_loss)

        return s

    def reduce(self, x, s, **kwargs):
        """Pool node features: X' = S^T X."""
        return ops.modal_dot(s, x, transpose_a=True)

    def connect(self, a, s, **kwargs):
        """Pool the adjacency matrix: A' = S^T A S."""
        return ops.matmul_at_b_a(s, a)

    def reduce_index(self, i, s, **kwargs):
        """Repeat each graph index k times, one entry per output node."""
        i_mean = tf.math.segment_mean(i, i)
        return ops.repeat(i_mean, tf.ones_like(i_mean) * self.k)

    def totvar_loss(self, a, s):
        """Graph total variation of the assignments, normalised by 2 * #edges.

        Args:
            a: adjacency matrix, either a SparseTensor (single mode) or a dense
                tensor of rank 2 (single mode) or 3 (batch mode).
            s: assignment matrix matching ``a``'s mode.

        Returns:
            A scalar loss (single mode) or a per-graph loss vector (batch mode).
            Graphs with no edges yield 0 instead of NaN.
        """
        if K.is_sparse(a):
            # Only visit stored edges: sum_e a_e * |s_i(e) - s_j(e)|.
            index_i = a.indices[:, 0]
            index_j = a.indices[:, 1]
            loss = tf.math.reduce_sum(
                a.values[:, tf.newaxis]
                * tf.math.abs(tf.gather(s, index_i) - tf.gather(s, index_j)),
                axis=(-2, -1),
            )
            # len() is undefined for symbolic tensors in graph mode, so use
            # the dynamic shape of the stored values instead.
            n_edges = tf.shape(a.values)[0]
        else:
            # Broadcast pairwise differences: entry [..., i, j, k] holds
            # s[..., j, k] - s[..., i, k]; this covers both (N, K) and (B, N, K).
            pairwise_tv = tf.math.reduce_sum(
                tf.math.abs(s[..., tf.newaxis, :, :] - s[..., :, tf.newaxis, :]),
                axis=-1,
            )
            loss = tf.math.reduce_sum(a * pairwise_tv, axis=(-2, -1))
            n_edges = tf.math.count_nonzero(a, axis=(-2, -1))
        n_edges = tf.cast(n_edges, loss.dtype)
        # divide_no_nan keeps edgeless graphs from producing a NaN loss.
        return tf.math.divide_no_nan(loss, 2 * n_edges)

    def balance_loss(self, s):
        """Normalised asymmetric l1-norm balance loss of the assignments.

        Returns:
            A scalar loss (rank-2 ``s``) or a per-graph loss vector (rank-3 ``s``).
        """
        n_nodes = tf.cast(tf.shape(s, out_type=tf.int32)[-2], s.dtype)
        # k-quantile: the (floor(N / k) + 1)-th largest entry of each column.
        rank = tf.cast(tf.math.floor(n_nodes / self.k) + 1, dtype=tf.int32)
        quantile = tf.math.top_k(tf.linalg.matrix_transpose(s), k=rank).values[..., -1]
        # Add a node axis so the quantile broadcasts over nodes:
        # (K,) -> (1, K) in single mode, (B, K) -> (B, 1, K) in batch mode.
        residual = s - quantile[..., tf.newaxis, :]
        # Asymmetric l1-norm: positive deviations weigh (k - 1), negative weigh 1.
        asym = tf.where(residual >= 0, (self.k - 1) * residual, -residual)
        asym_norm = tf.math.reduce_sum(asym, axis=(-2, -1))
        max_norm = n_nodes * (self.k - 1)
        return (max_norm - asym_norm) / max_norm

    def get_config(self):
        """Return the layer configuration, including the auxiliary-loss settings."""
        config = {
            "k": self.k,
            "mlp_hidden": self.mlp_hidden,
            "mlp_activation": self.mlp_activation,
            "use_bias": self.use_bias,
            "totvar_coeff": self.totvar_coeff,
            "balance_coeff": self.balance_coeff,
        }
        base_config = super().get_config()
        return {**base_config, **config}