-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrefactorgnn_demo_adagrad.py
203 lines (173 loc) · 9.36 KB
/
refactorgnn_demo_adagrad.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
"""A demo of ReFactor GNN induced by DistMult optimised with AdaGrad over a cross-entropy loss + N3 regularizer .
Implemented using PyTorch Geometric.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter
from torch_sparse import SparseTensor
from torch_geometric.nn.conv import MessagePassing
def parse_infer_cmd(infer_with):
    """Return the number of message-passing rounds encoded in ``infer_with``.

    The command is expected to look like ``'<k>-message-passing'``; only the
    text before the first ``'-'`` is parsed as an integer.

    Args:
        infer_with: str, inference command such as ``'3-message-passing'``.

    Returns:
        int, the leading round count ``k``.

    Raises:
        ValueError: if the text before the first ``'-'`` is not an integer
            (for example the bare string ``'message-passing'``).
    """
    prefix = infer_with.partition('-')[0]
    try:
        return int(prefix)
    except ValueError:
        # Same exception type as before, but with a message that names the
        # offending command instead of a bare "invalid literal for int()".
        raise ValueError(
            "infer_with must look like '<k>-message-passing' with an integer k, "
            "got {!r}".format(infer_with)
        ) from None
class ReFactorConv(MessagePassing):
    """ReFactor layer (see sec 4 in https://arxiv.org/pdf/2207.09980.pdf).

    Each call to ``propagate`` aggregates messages that are the negative
    gradients of a DistMult softmax cross-entropy loss plus an N3 regulariser
    with respect to the node states, and ``update`` applies them as one
    SGD or AdaGrad step of size ``fm_alpha``.
    """

    def __init__(self, n_ent=None, n_rel=None, n_hid=None,
                 fm_alpha=0.1, fm_lmbda=5e-3, fm_optim='SGD', fm_score_func='DistMult',
                 train_mp_init='input', infer_with='message-passing'):
        """Build the relation embeddings and the optional state tables.

        Args:
            n_ent: int, number of rows of the node state cache and of the
                AdaGrad accumulator.
            n_rel: int, number of relation embeddings.
            n_hid: int, embedding dimension.
            fm_alpha: float, step size applied to the aggregated message.
            fm_lmbda: float, weight of the N3 regulariser message.
            fm_optim: str, 'SGD' or 'AdaGrad'.
            fm_score_func: str, only 'DistMult' is implemented.
            train_mp_init: str, how each training message-passing starts,
                default to 'input'
                option1: 'input', start from the raw input;
                option2: 'state_caching', start from the cached node states
                    written by the previous call.
            infer_with: str, three options
                option1: 'input', simply use input as node states;
                option2: 'state_caching', use cached node states;
                option3: '<k>-message-passing' (e.g. '3-message-passing'),
                    use node states produced by k rounds of message-passing.
                Any other value, including the default 'message-passing',
                makes ``forward`` raise ValueError in eval mode.
        """
        super(ReFactorConv, self).__init__()
        assert fm_optim in ['SGD', 'AdaGrad'], "fm_optim must be 'SGD' or 'AdaGrad'"
        self.fm_alpha, self.fm_lmbda = fm_alpha, fm_lmbda
        self.fm_optim, self.fm_score_func = fm_optim, fm_score_func
        self.train_mp_init, self.infer_with = train_mp_init, infer_with

        self.rel_emb = nn.Embedding(n_rel, n_hid)
        self.rel_emb.weight.data *= 1e-3

        # Always defined so that clear() / set_init_node_states() can rely on it.
        self.init_node_states = None
        if self._uses_state_cache():  # node state cache/memory
            self.ent_state_cache = nn.Embedding(n_ent, n_hid)
            self.ent_state_cache.weight.data *= 1e-3
            # Caution !!! The node state memory is not updated by auto-differentiation
            self.ent_state_cache.weight.requires_grad = False

        if self.fm_optim == 'AdaGrad':
            self.epsilon = 1e-10
            # Running sum of squared aggregated messages, one row per entity.
            self.squared_grad_sum = nn.Embedding(n_ent, n_hid)
            self.squared_grad_sum.weight.data.zero_()
            self.squared_grad_sum.weight.requires_grad = False

    def _uses_state_cache(self):
        """Return True when training or inference reads/writes the node state cache."""
        return self.train_mp_init == 'state_caching' or self.infer_with == 'state_caching'

    def forward(self, x, edge_index, edge_type, g_node_idx=None, clr_ent_state=False):
        """Compute new node states.

        Args:
            x: node states, one row per node (N x K).
            edge_index: indexable pair; ``edge_index[0]`` / ``edge_index[1]``
                are used as the row / col of the sparse adjacency.
            edge_type: per-edge relation ids, stored as the adjacency values.
            g_node_idx: indices of the nodes in the node state cache and in
                the AdaGrad accumulator.
            clr_ent_state: bool, training only; when state caching is in use,
                reset the cache (and AdaGrad state) before this step.

        Returns:
            The updated node states.

        Raises:
            ValueError: in eval mode, if ``infer_with`` is not one of the
                supported commands.
        """
        if not self.training:  # inference-time
            return self._infer(x, edge_index, edge_type, g_node_idx)

        # training time
        if clr_ent_state and self._uses_state_cache():
            self.clear()  # restart the message-passing by clearing node states/optim states
        if self.train_mp_init == 'state_caching':
            x = self.pull(g_node_idx)
        # SparseTensor input makes propagate() dispatch to message_and_aggregate
        adj = SparseTensor(row=edge_index[0], col=edge_index[1], value=edge_type)
        x_new = self.propagate(adj, x=x, g_node_idx=g_node_idx)
        if self._uses_state_cache():
            self.push(g_node_idx, x_new)
        return x_new

    def _infer(self, x, edge_index, edge_type, g_node_idx):
        """Produce inference-time node states according to ``self.infer_with``."""
        if self.infer_with == 'state_caching':
            return self.pull(g_node_idx)
        if self.infer_with == 'input':
            return x
        if not self.infer_with.endswith('-message-passing'):
            # Previously this fell through and failed with UnboundLocalError.
            raise ValueError(
                "Unsupported infer_with={!r}; expected 'input', 'state_caching' "
                "or '<k>-message-passing' (e.g. '3-message-passing')".format(self.infer_with)
            )

        n_rounds = parse_infer_cmd(self.infer_with)  # on-the-fly l-round message-passing
        use_adagrad = self.fm_optim == 'AdaGrad'
        if use_adagrad:
            # clear optim states for inference, ensure that train/test optim
            # state doesn't interplay with each other
            saved_optim_state = self.optim_state_pull()
            self.optim_state_clear()
        try:
            adj = SparseTensor(row=edge_index[0], col=edge_index[1], value=edge_type)
            for _ in range(n_rounds):
                x = self.propagate(adj, x=x, g_node_idx=g_node_idx)
        finally:
            if use_adagrad:
                # restore the training optim state even if propagation failed
                self.optim_state_restore(saved_optim_state)
        return x

    def pull(self, node_idx):
        """Read from the node state cache"""
        return self.ent_state_cache(node_idx)

    def push(self, node_idx, node_state):
        """Write to the node state cache"""
        with torch.no_grad():
            self.ent_state_cache.weight.data[node_idx] = node_state

    def clear(self):
        """Reset the node state cache to init_node_states

        Without init_node_states the cache is re-drawn from a normal
        distribution scaled by 1e-3. With AdaGrad the accumulated squared
        gradients are cleared as well.
        """
        cache = self.ent_state_cache.weight
        if self.init_node_states is None:
            cache.data.normal_()
            cache.data *= 1e-3
        else:
            print('Reset ent_emb to init_node_states!')
            cache.data = self.init_node_states.clone().detach().type_as(cache.data)
        cache.requires_grad = False  # !!! the cache must stay outside autograd
        if self.fm_optim == 'AdaGrad':
            self.optim_state_clear()

    def set_init_node_states(self, node_features):
        """Remember a detached copy of ``node_features`` for later ``clear()`` calls."""
        self.init_node_states = node_features.clone().detach()

    def optim_state_restore(self, optim_state):
        """AdaGrad, optimiser state restore at the end of inference"""
        self.squared_grad_sum.weight.data = optim_state
        self.squared_grad_sum.weight.requires_grad = False

    def optim_state_clear(self):
        """AdaGrad, optimiser state clear at the beginning of inference"""
        # zero_() gives the same all-zero table as the former normal_() followed
        # by "*= 0.0", without drawing from (and advancing) the global RNG.
        self.squared_grad_sum.weight.data.zero_()
        self.squared_grad_sum.weight.requires_grad = False

    def optim_state_pull(self, node_idx=None):
        """AdaGrad, optimiser state read

        Args:
            node_idx: indices of the rows to read; None returns a detached
                copy of the whole accumulator.
        """
        if node_idx is None:
            return self.squared_grad_sum.weight.data.detach().clone()
        return self.squared_grad_sum(node_idx)

    def optim_state_accum(self, node_idx, node_state):
        """AdaGrad, accumulate running grad ** 2

        Args:
            node_idx: indices of the rows to update. None indexes with a new
                leading axis, so node_state is added to the whole table.
            node_state: running grad ** 2
        """
        with torch.no_grad():
            self.squared_grad_sum.weight.data[node_idx] += node_state

    def embed_rel(self, p):
        """Get relation embedding"""
        return self.rel_emb(p)

    def fm_score(self, hv, rp, ws):
        """Score function, corresponding to Gamma in the paper

        Might lead to huge memory consumption if |B| (num queries) or N (num
        candidate nodes) is big.

        Args:
            hv: subject representation, |B| X K
            rp: predicate representation, |B| X K
            ws: representations for all the objects, N X K
        Return: probabilities for P(w|v, r), |B| x N (num_queries x num_nodes)
        Raises:
            NotImplementedError: if fm_score_func is not 'DistMult'.
        """
        if self.fm_score_func != 'DistMult':
            raise NotImplementedError(
                "fm_score_func={!r} is not implemented".format(self.fm_score_func))
        return F.softmax((hv * rp) @ ws.t(), dim=1)

    def message_and_aggregate(self, adj, x=None):
        """Compute the aggregated message,
        avoiding explicitly materialising each message vector

        Args:
            adj: sparse tensor indicating index, transposed; its values are
                the relation ids of the edges.
            x: node states, N x K.
        Return: aggregated messages, N x K.
        Raises:
            NotImplementedError: if fm_score_func is not 'DistMult'.
        """
        if self.fm_score_func != 'DistMult':
            raise NotImplementedError(
                "fm_score_func={!r} is not implemented".format(self.fm_score_func))

        n_nodes = x.shape[0]
        v, w, p = adj.coo()
        hv, rp, hw = x[v], self.embed_rel(p), x[w]
        Z = self.fm_score(hv, rp, ws=x)  # |B| x N
        Zw = torch.gather(Z, dim=1, index=w.unsqueeze(1))  # |B| x 1
        Zv = torch.gather(Z, dim=1, index=v.unsqueeze(1))  # |B| x 1
        rp_hv = rp * hv

        # Direction 1: w2v, z_{v<-*} = - grad_v
        msg_fit = rp * hw - rp * (Z @ x)
        msg_reg = - self.fm_lmbda * 3 * (hv ** 2) * torch.sign(hv)
        aggr_other2v = scatter(src=msg_fit + msg_reg,
                               index=v, dim=0, dim_size=n_nodes)

        # Direction 2: v2w, z_{w<-v} = - grad_w
        msg_fit = (1 - Zw) * rp_hv
        msg_reg = - self.fm_lmbda * 3 * (hw ** 2) * torch.sign(hw)
        aggr_v2neighbor = scatter(src=msg_fit + msg_reg,
                                  index=w, dim=0, dim_size=n_nodes)

        # z_{u<-v} = - grad_u, already one row per node (N x K), so no
        # scatter is needed to aggregate it.
        aggr_v2negative = - Z.T @ rp_hv
        # Add back the contributions of each edge's own w and v. index_add_
        # accumulates over repeated indices, whereas "t[idx] += src" keeps
        # only one write per repeated index and silently drops the rest.
        aggr_v2negative.index_add_(0, w, Zw * rp_hv)
        aggr_v2negative.index_add_(0, v, Zv * rp_hv)

        # NOTE(review): the original code appends the negative term twice, so
        # it is counted with weight 2. This looks like a duplicated line but is
        # kept to preserve the numerics -- TODO confirm against the paper.
        aggr_out = [aggr_other2v, aggr_v2neighbor, aggr_v2negative, aggr_v2negative]
        return torch.stack(aggr_out, dim=0).sum(dim=0)

    def update(self, aggr_out, x, g_node_idx):
        """Apply one optimiser step of size fm_alpha to the node states.

        With AdaGrad the aggregated message is first divided by the square
        root of the accumulated squared messages (plus epsilon).
        """
        if self.fm_optim == 'AdaGrad':
            self.optim_state_accum(g_node_idx, aggr_out.detach() ** 2)
            state_sum = self.optim_state_pull(g_node_idx)
            aggr_out = aggr_out / (torch.sqrt(state_sum) + self.epsilon)
        return x + self.fm_alpha * aggr_out