forked from dmlc/dgl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
TAHIN.py
192 lines (154 loc) · 7.64 KB
/
TAHIN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
from dgl.nn.pytorch import GATConv
#Semantic attention in the metapath-based aggregation (the same as that in the HAN)
class SemanticAttention(nn.Module):
    """Semantic-level attention that fuses per-metapath embeddings (as in HAN).

    A small scoring MLP (Linear -> Tanh -> Linear without bias) scores every
    metapath embedding of every node.  The scores are averaged over nodes,
    soft-maxed over metapaths, and used as weights for a sum over metapaths.

    Args:
        in_size: size of the last dimension of the input ``z``.
        hidden_size: width of the hidden layer of the scoring MLP.
    """

    def __init__(self, in_size, hidden_size=128):
        super(SemanticAttention, self).__init__()
        scorer_layers = [
            nn.Linear(in_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1, bias=False),
        ]
        self.project = nn.Sequential(*scorer_layers)

    def forward(self, z):
        """Fuse metapath embeddings with learned semantic attention.

        Shape of z: (N, M, D*K)
            N: number of nodes
            M: number of metapath patterns
            D: per-head output size of the upstream GAT layers
            K: number of heads

        Returns:
            Tensor of shape (N, D*K).
        """
        per_node_scores = self.project(z)              # (N, M, 1)
        path_scores = per_node_scores.mean(dim=0)      # (M, 1): one score per metapath
        path_weights = path_scores.softmax(dim=0)      # (M, 1), sums to 1 over metapaths
        # (M, 1) broadcasts against (N, M, D*K): every node gets the same weights.
        weighted = path_weights * z                    # (N, M, D*K)
        return torch.sum(weighted, dim=1)              # (N, D*K)
#Metapath-based aggregation (the same as the HANLayer)
class HANLayer(nn.Module):
    """Metapath-based aggregation (same structure as the HAN layer).

    For every metapath pattern a GATConv is applied to the corresponding
    metapath-reachable graph; the per-metapath outputs (heads flattened) are
    stacked along a new metapath dimension and fused by ``SemanticAttention``.

    Args:
        meta_path_patterns: sequence of metapaths; each one is converted to a
            tuple and passed to ``dgl.metapath_reachable_graph``.
        in_size: input feature size of every GATConv.
        out_size: per-head output size of every GATConv.
        layer_num_heads: number of attention heads of every GATConv.
        dropout: used as both feature dropout and attention dropout.
    """

    def __init__(self, meta_path_patterns, in_size, out_size, layer_num_heads, dropout):
        super(HANLayer, self).__init__()
        num_paths = len(meta_path_patterns)
        # One GAT layer for each metapath-based adjacency matrix, in pattern order.
        self.gat_layers = nn.ModuleList([
            GATConv(
                in_size,
                out_size,
                layer_num_heads,
                feat_drop=dropout,
                attn_drop=dropout,
                activation=F.elu,
                allow_zero_in_degree=True,
            )
            for _ in range(num_paths)
        ])
        self.semantic_attention = SemanticAttention(in_size=out_size * layer_num_heads)
        # Tuples (hashable) so each pattern can key the reachable-graph cache.
        self.meta_path_patterns = [tuple(path) for path in meta_path_patterns]
        # Cache keyed on graph identity: last graph seen and its reachable graphs.
        self._cached_graph = None
        self._cached_coalesced_graph = {}

    def forward(self, g, h):
        """Run metapath-based aggregation.

        Args:
            g: heterograph from which the metapath-reachable graphs are built.
            h: node features fed to every GATConv.

        Returns:
            Semantic-attention fusion of the per-metapath embeddings, (N, D * K).
        """
        needs_rebuild = self._cached_graph is None or self._cached_graph is not g
        if needs_rebuild:
            # Reachable graphs are rebuilt only when a different graph object arrives.
            self._cached_graph = g
            cache = self._cached_coalesced_graph
            cache.clear()
            for path in self.meta_path_patterns:
                cache[path] = dgl.metapath_reachable_graph(g, path)

        per_path = [
            gat(self._cached_coalesced_graph[path], h).flatten(1)
            for gat, path in zip(self.gat_layers, self.meta_path_patterns)
        ]
        stacked = torch.stack(per_path, dim=1)  # (N, M, D * K)
        return self.semantic_attention(stacked)  # (N, D * K)
#Relational neighbor aggregation
class RelationalAGG(nn.Module):
    """Relational neighbor aggregation over a heterograph (produces h1).

    For every canonical edge type (src, etype, dst) each edge gets the logit
    ``W_A[etype](W_T[etype](h_src) * h_dst)``.  Logits are soft-maxed over
    *all* incoming edges of a destination node (across edge types), and each
    destination's new feature is the attention-weighted sum of its source
    nodes' original features.  ReLU, LayerNorm and dropout are then applied
    per node type.

    Note: ``W_T(h_src)`` (out_size) is multiplied element-wise with ``h_dst``
    (in_size), and the aggregated in_size features go through
    ``LayerNorm(out_size)``, so this module requires ``in_size == out_size``.

    Args:
        g: heterograph whose ``etypes`` define the per-relation weights.
        in_size: input feature size.
        out_size: output feature size.
        dropout: dropout probability applied to the output features.
    """

    def __init__(self, g, in_size, out_size, dropout=0.1):
        super(RelationalAGG, self).__init__()
        self.in_size = in_size
        self.out_size = out_size
        # Transform weights for different types of edges
        self.W_T = nn.ModuleDict({
            name: nn.Linear(in_size, out_size, bias=False) for name in g.etypes
        })
        # Attention weights for different types of edges
        self.W_A = nn.ModuleDict({
            name: nn.Linear(out_size, 1, bias=False) for name in g.etypes
        })
        # layernorm
        self.layernorm = nn.LayerNorm(out_size)
        # dropout layer
        self.dropout = nn.Dropout(dropout)

    def forward(self, g, feat_dict):
        """Compute h1 for every node type.

        Args:
            g: heterograph to aggregate over.  It is not modified: all
                intermediate features live in a local scope.
            feat_dict: mapping node type -> node features (h0); assumed to be
                (num_nodes(ntype), in_size) — TODO confirm with callers.

        Returns:
            dict mapping node type -> aggregated features after ReLU,
            LayerNorm and dropout.
        """
        # local_scope discards every feature written below on exit, so the
        # caller's graph does not accumulate 'h'/'t_h'/'x'/'att' fields that
        # would keep this pass's autograd graph alive.
        with g.local_scope():
            # 1) Per-edge attention logits and, per destination node, the max
            #    logit over all incoming edges of every type (for stability).
            funcs = {}
            for srctype, etype, dsttype in g.canonical_etypes:
                cetype = (srctype, etype, dsttype)
                g.nodes[dsttype].data['h'] = feat_dict[dsttype]  # nodes' original feature
                g.nodes[srctype].data['h'] = feat_dict[srctype]
                g.nodes[srctype].data['t_h'] = self.W_T[etype](feat_dict[srctype])  # src nodes' transformed feature
                g.apply_edges(fn.u_mul_v('t_h', 'h', 'x'), etype=cetype)
                logits = self.W_A[etype](g.edges[cetype].data['x'])
                g.edges[cetype].data['x'] = logits
                # The shift needs no gradient: softmax is invariant to it.
                g.edges[cetype].data['x_d'] = logits.detach()
                funcs[cetype] = (fn.copy_e('x_d', 'm'), fn.max('m', 'x_max'))
            g.multi_update_all(funcs, 'max')

            # 2) Shifted numerator exp(x - max) (never overflows) and its sum
            #    over all incoming edges, across edge types.
            funcs = {}
            for cetype in g.canonical_etypes:
                g.apply_edges(fn.e_sub_v('x', 'x_max', 'ex'), etype=cetype)
                g.edges[cetype].data['ex'] = torch.exp(g.edges[cetype].data['ex'])
                funcs[cetype] = (fn.copy_e('ex', 'm'), fn.sum('m', 'att'))
            g.multi_update_all(funcs, 'sum')

            # 3) Attention weights (numerator / denominator) and the weighted
            #    sum of source features: \sum(h0 * att) -> h1.
            funcs = {}
            for cetype in g.canonical_etypes:
                g.apply_edges(fn.e_div_v('ex', 'att', 'att'), etype=cetype)
                funcs[cetype] = (fn.u_mul_e('h', 'att', 'm'), fn.sum('m', 'h'))
            g.multi_update_all(funcs, 'sum')

            # Activation, layernorm and dropout.  Out-of-place ReLU: for a node
            # type with no incoming edges 'h' is still the caller's tensor
            # (possibly a leaf Parameter), which must not be modified in place.
            return {
                ntype: self.dropout(self.layernorm(F.relu(g.nodes[ntype].data['h'])))
                for ntype in g.ntypes
            }
class TAHIN(nn.Module):
    """TAHIN model: fuses h0, h1 and h2 node embeddings and scores pairs.

    h0: learnable per-node-type embeddings (``feature_dict``).
    h1: relational neighbor aggregation (``RelationalAGG``).
    h2: metapath-based aggregation, one ``HANLayer`` per key of
        ``meta_path_patterns``.

    Args:
        g: heterograph; its node types and node counts size the h0 embeddings.
        meta_path_patterns: dict mapping a node type to the metapath patterns
            used by that node type's HANLayer.
        in_size: size of the h0 embeddings.
        out_size: size of h1 and of the fused embeddings.  Must equal
            ``in_size`` (RelationalAGG multiplies ``W_T(h0)`` element-wise with
            h0 and layer-normalises aggregated h0 with ``LayerNorm(out_size)``).
        num_heads: number of GAT heads in every HANLayer.
        dropout: dropout for the HANLayers and before the final scorer.

    Raises:
        ValueError: if ``in_size != out_size``.
    """

    def __init__(self, g, meta_path_patterns, in_size, out_size, num_heads, dropout):
        super(TAHIN, self).__init__()
        # Fail fast: any other combination only crashes later inside forward().
        if in_size != out_size:
            raise ValueError(
                f"TAHIN requires in_size == out_size, got in_size={in_size}, out_size={out_size}"
            )
        # embeddings for different types of nodes, h0
        self.initializer = nn.init.xavier_uniform_
        self.feature_dict = nn.ParameterDict({
            ntype: nn.Parameter(self.initializer(torch.empty(g.num_nodes(ntype), in_size)))
            for ntype in g.ntypes
        })
        # relational neighbor aggregation, this produces h1
        self.RelationalAGG = RelationalAGG(g, in_size, out_size)
        # metapath-based aggregation modules, this produces h2
        self.meta_path_patterns = meta_path_patterns
        # one HANLayer per key (e.g. one for users, one for items)
        self.hans = nn.ModuleDict({
            key: HANLayer(value, in_size, out_size, num_heads, dropout)
            for key, value in self.meta_path_patterns.items()
        })
        # layer1 fuses h1 (out_size) with h2 (num_heads * out_size);
        # layer2 fuses layer1's output (out_size) with h0 (in_size).
        self.user_layer1 = nn.Linear((num_heads + 1) * out_size, out_size, bias=True)
        self.user_layer2 = nn.Linear(out_size + in_size, out_size, bias=True)
        self.item_layer1 = nn.Linear((num_heads + 1) * out_size, out_size, bias=True)
        self.item_layer2 = nn.Linear(out_size + in_size, out_size, bias=True)
        # layernorm
        self.layernorm = nn.LayerNorm(out_size)
        # network to score the node pairs
        self.pred = nn.Linear(out_size, out_size)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(out_size, 1)

    def _fuse(self, h0, h1, h2, layer1, layer2):
        """Combine h0/h1/h2 of one node type: two linears, ReLU, LayerNorm."""
        emb = layer1(torch.cat((h1, h2), 1))
        emb = layer2(torch.cat((emb, h0), 1))
        return self.layernorm(F.relu(emb))

    def forward(self, g, user_key, item_key, user_idx, item_idx):
        """Score (user, item) pairs.

        Args:
            g: the heterograph.
            user_key / item_key: node types of users / items; both must be
                keys of ``meta_path_patterns``.
            user_idx / item_idx: indices into the user / item embeddings; the
                selected rows are combined element-wise, so they are expected
                to describe matching pairs.

        Returns:
            Sigmoid scores with the trailing size-1 dimension squeezed.
        """
        h0 = self.feature_dict
        # relational neighbor aggregation, h1
        h1 = self.RelationalAGG(g, h0)
        # metapath-based aggregation, h2
        h2 = {key: self.hans[key](g, h0[key]) for key in self.meta_path_patterns}

        # update node embeddings
        user_emb = self._fuse(h0[user_key], h1[user_key], h2[user_key],
                              self.user_layer1, self.user_layer2)
        item_emb = self._fuse(h0[item_key], h1[item_key], h2[item_key],
                              self.item_layer1, self.item_layer2)

        # obtain users/items embeddings and their interactions
        interaction = user_emb[user_idx] * item_emb[item_idx]
        # score the node pairs
        pred = self.fc(self.dropout(self.pred(interaction)))
        return torch.sigmoid(pred).squeeze(1)