decoder.py
import keras


class TransformerDecoder(keras.layers.Layer):
    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        # Causal self-attention over the target sequence and cross-attention
        # over the encoder output, using the built-in Keras layer.
        self.causal_self_attention = keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim // num_heads)
        self.cross_attention = keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim // num_heads)
        # Position-wise feed-forward block: expand to dense_dim, project back.
        self.feed_forward = keras.Sequential(
            [keras.layers.Dense(dense_dim, activation="relu"),
             keras.layers.Dense(embed_dim)]
        )
        self.layer_norm_1 = keras.layers.LayerNormalization()
        self.layer_norm_2 = keras.layers.LayerNormalization()
        self.layer_norm_3 = keras.layers.LayerNormalization()

    def get_config(self):
        config = super().get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "dense_dim": self.dense_dim,
            "num_heads": self.num_heads,
        })
        return config

    def call(self, x, context):
        # Post-layer normalization: each sublayer's output is added back to
        # its input (residual connection) and then normalized.
        x = self.layer_norm_1(
            x + self.causal_self_attention(
                query=x, value=x, key=x, use_causal_mask=True))
        x = self.layer_norm_2(
            x + self.cross_attention(query=x, value=context, key=context))
        x = self.layer_norm_3(x + self.feed_forward(x))
        return x
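
A minimal smoke test of the layer, assuming random inputs; the hyperparameters (embed_dim=256, dense_dim=1024, num_heads=8) and batch/sequence sizes are illustrative placeholders, not values taken from this repo:

if __name__ == "__main__":
    import numpy as np

    decoder = TransformerDecoder(embed_dim=256, dense_dim=1024, num_heads=8)
    # (batch, target_len, embed_dim) target embeddings and
    # (batch, source_len, embed_dim) encoder output.
    targets = np.random.normal(size=(2, 10, 256)).astype("float32")
    context = np.random.normal(size=(2, 20, 256)).astype("float32")
    out = decoder(targets, context)
    print(out.shape)  # (2, 10, 256): output keeps the target sequence shape

    # Serialization round-trip, enabled by the get_config override above.
    clone = TransformerDecoder.from_config(decoder.get_config())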