multiHeadAttention.py
import tensorflow as tf
from tensorflow import keras


class MultiHeadAttention(keras.layers.Layer):
    """Multi-head attention layer that splits the embedding dimension across h heads."""

    def __init__(self, embed_dim, h, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.h = h
        if embed_dim % h != 0:
            raise ValueError(
                f"dimension of the embedding space = {embed_dim} should be divisible by number of heads = {h}"
            )
        # Linear projections for queries, keys, values, and the concatenated output.
        self.q_linear = keras.layers.Dense(embed_dim)
        self.k_linear = keras.layers.Dense(embed_dim)
        self.v_linear = keras.layers.Dense(embed_dim)
        self.concat_linear = keras.layers.Dense(embed_dim)

    def split_heads(self, x, batch_size):
        # (batch, seq_len, embed_dim) -> (batch, h, seq_len, embed_dim // h)
        x = tf.reshape(x, shape=(batch_size, -1, self.h, self.embed_dim // self.h))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def concat_heads(self, x, batch_size):
        # (batch, h, seq_len, embed_dim // h) -> (batch, seq_len, embed_dim)
        x = tf.transpose(x, perm=[0, 2, 1, 3])
        return tf.reshape(x, (batch_size, -1, self.embed_dim))

    def call(self, q, k, v, use_causal_mask=False):
        batch_size = tf.shape(k)[0]
        # Project the inputs, then split each projection into h heads.
        q = self.q_linear(q)
        k = self.k_linear(k)
        v = self.v_linear(v)
        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)
        # scaled_dot_product_attention is referenced here but not defined in this
        # file; it is expected to come from elsewhere in the repository (a minimal
        # sketch is given below).
        attention = scaled_dot_product_attention(q, k, v, use_causal_mask)
        # Merge the heads back together and apply the final linear projection.
        concat = self.concat_heads(attention, batch_size)
        concat = self.concat_linear(concat)
        return concat

    def get_config(self):
        config = super().get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "h": self.h,
        })
        return config
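

# The layer above calls scaled_dot_product_attention, which this file does not
# define. The function below is a minimal sketch under the assumption of standard
# scaled dot-product attention with an optional causal mask; it is not necessarily
# the repository's own implementation.
def scaled_dot_product_attention(q, k, v, use_causal_mask=False):
    # q, k, v: (batch, h, seq_len, depth)
    matmul_qk = tf.matmul(q, k, transpose_b=True)   # (batch, h, seq_q, seq_k)
    dk = tf.cast(tf.shape(k)[-1], tf.float32)
    scaled_logits = matmul_qk / tf.math.sqrt(dk)
    if use_causal_mask:
        # Mask out future positions so each query attends only to itself and earlier keys.
        seq_q = tf.shape(scaled_logits)[-2]
        seq_k = tf.shape(scaled_logits)[-1]
        causal_mask = tf.linalg.band_part(tf.ones((seq_q, seq_k)), -1, 0)
        scaled_logits += (1.0 - causal_mask) * -1e9
    weights = tf.nn.softmax(scaled_logits, axis=-1)
    return tf.matmul(weights, v)                    # (batch, h, seq_q, depth)


# Hypothetical usage example; shapes and values are chosen only for illustration.
if __name__ == "__main__":
    mha = MultiHeadAttention(embed_dim=64, h=8)
    x = tf.random.normal((2, 10, 64))               # (batch, seq_len, embed_dim)
    out = mha(x, x, x, use_causal_mask=True)        # causal self-attention
    print(out.shape)                                # (2, 10, 64)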