# -----------------------------------------------------------
# GAT.py
# Dual Semantic Relations Attention Network (DSRAN) implementation
# "Learning Dual Semantic Relations with Graph Attention for Image-Text Matching"
# Keyu Wen, Xiaodong Gu, and Qingrong Cheng
# IEEE Transactions on Circuits and Systems for Video Technology, 2020
# Written by Keyu Wen, 2020
# ------------------------------------------------------------
import math
import torch
from torch import nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super(MultiHeadAttention, self).__init__()
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        # (batch, num_nodes, all_head_size) -> (batch, num_heads, num_nodes, head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, input_graph):
        nodes_q = self.query(input_graph)
        nodes_k = self.key(input_graph)
        nodes_v = self.value(input_graph)
        nodes_q_t = self.transpose_for_scores(nodes_q)
        nodes_k_t = self.transpose_for_scores(nodes_k)
        nodes_v_t = self.transpose_for_scores(nodes_v)
        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(nodes_q_t, nodes_k_t.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # No attention mask is applied here: every node attends to all nodes in the graph.
        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but it is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)
        nodes_new = torch.matmul(attention_probs, nodes_v_t)
        nodes_new = nodes_new.permute(0, 2, 1, 3).contiguous()
        new_nodes_shape = nodes_new.size()[:-2] + (self.all_head_size,)
        nodes_new = nodes_new.view(*new_nodes_shape)
        return nodes_new
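

# Minimal shape-check sketch for MultiHeadAttention. The SimpleNamespace config
# and the sizes below are illustrative assumptions; the real config comes from
# the surrounding DSRAN training code, which is not part of this file.
def _demo_multi_head_attention():
    from types import SimpleNamespace
    config = SimpleNamespace(hidden_size=768,
                             num_attention_heads=8,
                             attention_probs_dropout_prob=0.1)
    mha = MultiHeadAttention(config).eval()
    nodes = torch.randn(2, 36, config.hidden_size)  # (batch, num_nodes, hidden_size)
    with torch.no_grad():
        out = mha(nodes)
    assert out.shape == nodes.shape  # self-attention preserves the node-feature shape
    return out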


class GATLayer(nn.Module):
    def __init__(self, config):
        super(GATLayer, self).__init__()
        self.mha = MultiHeadAttention(config)
        self.fc_in = nn.Linear(config.hidden_size, config.hidden_size)
        self.bn_in = nn.BatchNorm1d(config.hidden_size)
        self.dropout_in = nn.Dropout(config.hidden_dropout_prob)
        self.fc_int = nn.Linear(config.hidden_size, config.hidden_size)
        self.fc_out = nn.Linear(config.hidden_size, config.hidden_size)
        self.bn_out = nn.BatchNorm1d(config.hidden_size)
        self.dropout_out = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_graph):
        # Multi-head graph attention, then a residual connection and BatchNorm1d
        # (normalized over the feature dimension, hence the permutes).
        attention_output = self.mha(input_graph)
        attention_output = self.fc_in(attention_output)
        attention_output = self.dropout_in(attention_output)
        attention_output = self.bn_in((attention_output + input_graph).permute(0, 2, 1)).permute(0, 2, 1)
        # Position-wise feed-forward sub-layer, again with residual connection
        # and batch normalization.
        intermediate_output = self.fc_int(attention_output)
        intermediate_output = F.relu(intermediate_output)
        intermediate_output = self.fc_out(intermediate_output)
        intermediate_output = self.dropout_out(intermediate_output)
        graph_output = self.bn_out((intermediate_output + attention_output).permute(0, 2, 1)).permute(0, 2, 1)
        return graph_output
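

# A minimal end-to-end sketch: stack two GATLayer blocks and push a batch of
# random node features through them. The layer count, batch size, node count and
# SimpleNamespace config are illustrative assumptions, not the original DSRAN
# setup (the full model lives in the surrounding training code).
if __name__ == '__main__':
    from types import SimpleNamespace
    config = SimpleNamespace(hidden_size=768,
                             num_attention_heads=8,
                             attention_probs_dropout_prob=0.1,
                             hidden_dropout_prob=0.1)
    _demo_multi_head_attention()
    gat = nn.Sequential(GATLayer(config), GATLayer(config))
    gat.eval()  # put BatchNorm1d/Dropout in eval mode for a deterministic check
    nodes = torch.randn(4, 36, config.hidden_size)  # (batch, num_nodes, hidden_size)
    with torch.no_grad():
        out = gat(nodes)
    print(out.shape)  # expected: torch.Size([4, 36, 768])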