-
Notifications
You must be signed in to change notification settings - Fork 2
/
graph.py
78 lines (67 loc) · 3.14 KB
/
graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from __future__ import print_function
from keras import activations, initializers, constraints
from keras import regularizers
from keras.engine import Layer
import keras.backend as K
class GraphConvolution(Layer):
    """Basic graph convolution layer as in https://arxiv.org/abs/1609.02907

    The layer is called on a list ``[features, basis_1, ..., basis_S]`` with
    ``S == support``. For each basis matrix it computes
    ``K.dot(basis_i, features)``, concatenates the ``S`` results along axis 1,
    multiplies them by a kernel of shape ``(input_dim * support, units)``,
    optionally adds a bias of shape ``(units,)`` and applies ``activation``.

    Args:
        units: Number of output features (last dimension of the kernel).
        support: Number of basis matrices that follow ``features`` in the
            input list. Must be >= 1.
        activation: Resolved with ``keras.activations.get``.
        use_bias: Whether to create and add a bias vector.
        kernel_initializer, bias_initializer: Resolved with
            ``keras.initializers.get``.
        kernel_regularizer, bias_regularizer, activity_regularizer: Resolved
            with ``keras.regularizers.get``.
        kernel_constraint, bias_constraint: Resolved with
            ``keras.constraints.get``.
        **kwargs: Forwarded to ``Layer.__init__``. An ``input_dim`` shorthand
            is translated into ``input_shape=(input_dim,)``.

    Raises:
        ValueError: If ``support < 1``, or (at build time) if the features
            input is not rank 2.
    """

    def __init__(self, units, support=1,
                 activation=None,
                 use_bias=True,
                 kernel_initializer='glorot_uniform',  # Glorot/Xavier *uniform*
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        # Explicit check rather than ``assert``: asserts vanish under ``python -O``.
        if support < 1:
            raise ValueError(
                'GraphConvolution: `support` must be >= 1, got %r' % (support,))
        # Translate the ``input_dim`` shorthand into ``input_shape``.
        if 'input_shape' not in kwargs and 'input_dim' in kwargs:
            kwargs['input_shape'] = (kwargs.pop('input_dim'),)
        super(GraphConvolution, self).__init__(**kwargs)
        self.units = units
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        self.supports_masking = True
        self.support = support

    def compute_output_shape(self, input_shapes):
        """Return ``(features_shape[0], units)``.

        NOTE(review): this assumes every basis matrix has as many rows as
        ``features`` (e.g. square N x N propagation matrices); ``call``
        actually produces ``basis_i.shape[0]`` rows.
        """
        features_shape = input_shapes[0]
        return (features_shape[0], self.units)

    def build(self, input_shapes):
        """Create the kernel (and bias, if ``use_bias``) from the features shape."""
        features_shape = input_shapes[0]
        if len(features_shape) != 2:
            raise ValueError(
                'GraphConvolution expects rank-2 features, got shape %r'
                % (features_shape,))
        input_dim = features_shape[1]
        # ``call`` concatenates ``support`` blocks of ``input_dim`` columns each.
        self.kernel = self.add_weight(shape=(input_dim * self.support, self.units),
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        if self.use_bias:
            self.bias = self.add_weight(shape=(self.units,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None
        self.built = True

    def call(self, inputs, mask=None):
        """Propagate features through each basis matrix, then apply the kernel."""
        features = inputs[0]  # X
        basis = inputs[1:]    # basis / propagation matrices ("A_" in the paper's notation)
        supports = [K.dot(basis[i], features) for i in range(self.support)]
        supports = K.concatenate(supports, axis=1)
        output = K.dot(supports, self.kernel)
        # Test for presence, not truthiness: truth-testing a backend variable
        # is backend-dependent and can raise for a non-scalar bias.
        if self.bias is not None:
            output += self.bias
        return self.activation(output)

    def get_config(self):
        """Return the layer config so models using it can be saved and reloaded.

        Without this override the base ``Layer.get_config`` omits ``units``
        and ``support``, so the layer could not be re-created from its config.
        """
        config = {
            'units': self.units,
            'support': self.support,
            'activation': activations.serialize(self.activation),
            'use_bias': self.use_bias,
            'kernel_initializer': initializers.serialize(self.kernel_initializer),
            'bias_initializer': initializers.serialize(self.bias_initializer),
            'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
            'bias_regularizer': regularizers.serialize(self.bias_regularizer),
            'activity_regularizer': regularizers.serialize(self.activity_regularizer),
            'kernel_constraint': constraints.serialize(self.kernel_constraint),
            'bias_constraint': constraints.serialize(self.bias_constraint),
        }
        base_config = super(GraphConvolution, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))