positional_embedding.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch.nn as nn
from fairseq.modules.learned_positional_embedding import LearnedPositionalEmbedding
from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding


def PositionalEmbedding(
    num_embeddings: int,
    embedding_dim: int,
    padding_idx: int,
    learned: bool = False,
    args: dict = {},
):
    """Build a positional embedding module, either learned or sinusoidal."""
    if learned:
        # If padding_idx is specified then offset the embedding ids by
        # this index and adjust num_embeddings appropriately.
        # TODO: The right place for this offset would be inside
        # LearnedPositionalEmbedding. Move this there for a cleaner implementation.
        if padding_idx is not None:
            num_embeddings = num_embeddings + padding_idx + 1
        m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx)
        nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
        if padding_idx is not None:
            nn.init.constant_(m.weight[padding_idx], 0)
    else:
        m = SinusoidalPositionalEmbedding(
            embedding_dim,
            padding_idx,
            init_size=num_embeddings + padding_idx + 1,
            args=args,
        )
    return m
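

# --- Usage sketch (illustrative, not part of the original file) -------------
# A minimal example of calling the factory above. It assumes the returned
# module follows upstream fairseq's LearnedPositionalEmbedding interface:
# forward() takes a (batch, seq_len) LongTensor of token ids, derives
# positions from the non-padding entries, and returns a
# (batch, seq_len, embedding_dim) tensor. The constants below are
# hypothetical, and the sinusoidal branch's extra `args` parameter is
# fork-specific, so the sketch uses the learned branch only.
if __name__ == "__main__":
    import torch

    pos_emb = PositionalEmbedding(
        num_embeddings=512,   # hypothetical maximum number of positions
        embedding_dim=64,
        padding_idx=1,        # fairseq's conventional pad index
        learned=True,
    )
    tokens = torch.tensor([[5, 6, 7, 1, 1]])  # index 1 marks padding
    out = pos_emb(tokens)
    print(out.shape)  # expected: torch.Size([1, 5, 64])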