positionalEmbedding.py
import keras
import tensorflow as tf


class PositionalEmbedding(keras.layers.Layer):
    def __init__(self, sequence_length, input_dim, output_dim, **kwargs):
        super().__init__(**kwargs)
        # Token embedding layer: maps token ids to dense vectors
        self.token_embeddings = keras.layers.Embedding(input_dim=input_dim, output_dim=output_dim)
        # Position embedding layer: maps positions 0..sequence_length-1 to dense vectors
        self.position_embeddings = keras.layers.Embedding(input_dim=sequence_length, output_dim=output_dim)
        self.sequence_length = sequence_length
        self.input_dim = input_dim
        self.output_dim = output_dim

    def call(self, inputs):
        embedded_tokens = self.token_embeddings(inputs)  # embed the tokens
        length = tf.shape(inputs)[-1]
        positions = tf.range(start=0, limit=length, delta=1)  # positional indices 0..length-1
        embedded_positions = self.position_embeddings(positions)  # embed the positions
        return embedded_tokens + embedded_positions  # sum token and position embeddings

    def compute_mask(self, inputs, mask=None):
        # Mask out padding tokens (id 0) so downstream layers can ignore them
        return keras.ops.not_equal(inputs, 0)

    def get_config(self):
        # Return constructor arguments so the layer can be serialized and reloaded
        config = super().get_config()
        config.update({
            "sequence_length": self.sequence_length,
            "input_dim": self.input_dim,
            "output_dim": self.output_dim,
        })
        return config
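

# Minimal usage sketch (not part of the original file): the hyperparameter values
# below are illustrative assumptions. It builds the layer, embeds a random batch of
# token ids, and checks the output shape, assuming Keras 3 with the TensorFlow backend.
if __name__ == "__main__":
    vocab_size = 20000       # assumed vocabulary size
    sequence_length = 128    # assumed maximum sequence length
    embed_dim = 256          # assumed embedding dimension

    embedding_layer = PositionalEmbedding(
        sequence_length=sequence_length,
        input_dim=vocab_size,
        output_dim=embed_dim,
    )

    # Random integer token ids for a batch of 2 sequences
    token_ids = tf.random.uniform(
        (2, sequence_length), minval=0, maxval=vocab_size, dtype=tf.int32
    )
    embedded = embedding_layer(token_ids)
    print(embedded.shape)  # expected: (2, sequence_length, embed_dim)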