-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsinusoidal_positional_embedding.py
116 lines (104 loc) · 4.27 KB
/
sinusoidal_positional_embedding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Any, Optional
import torch
import torch.onnx.operators
from fairseq import utils
from torch import Tensor, nn
class SinusoidalPositionalEmbedding(nn.Module):
    """Produce sinusoidal positional embeddings of any length.

    A lookup table of sinusoids is built once and grown on demand in
    ``forward``. The row at ``padding_idx`` is all zeros, so padding symbols
    receive a zero embedding.
    """

    def __init__(self, embedding_dim, padding_idx, init_size=1024, args=None):
        """Build the initial table.

        Args:
            embedding_dim: width of every positional embedding.
            padding_idx: index of the padding symbol; its table row is zeroed.
            init_size: number of rows in the initial table.
            args: optional configuration object. Only its ``global_positions``
                attribute is consulted (in ``forward``). The flag is read with
                ``getattr``, so a plain dict never enables it.
        """
        super().__init__()
        # A fresh dict per instance: a literal `{}` default would be one
        # object shared (and mutable) across every instance of the class.
        self.args = {} if args is None else args
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        # Plain attribute (neither parameter nor buffer); `forward` moves it
        # to the device/dtype of `_float_tensor` on every call.
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            init_size, embedding_dim, padding_idx
        )
        self.onnx_trace = False
        # Registered buffer used only as the device/dtype reference for the
        # table in `forward`.
        self.register_buffer("_float_tensor", torch.FloatTensor(1))
        self.max_positions = int(1e5)

    def prepare_for_onnx_export_(self):
        """Switch ``forward`` to its ONNX-traceable code paths."""
        self.onnx_trace = True

    @staticmethod
    def get_embedding(
        num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
    ):
        """Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".

        Args:
            num_embeddings: number of positions (rows) in the table.
            embedding_dim: width of each embedding (columns).
            padding_idx: if given, that row is set to zero.

        Returns:
            A float tensor of shape ``(num_embeddings, embedding_dim)`` whose
            first ``embedding_dim // 2`` columns are sines and the next
            ``embedding_dim // 2`` columns cosines; an odd ``embedding_dim``
            gets one trailing zero column.
        """
        half_dim = embedding_dim // 2
        # With a single frequency (half_dim == 1) the only exponent index is 0,
        # so the step size is irrelevant; clamping the denominator avoids the
        # ZeroDivisionError that `half_dim - 1` would cause for dims 2 and 3.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        inv_freq = torch.exp(torch.arange(half_dim, dtype=torch.float) * -log_step)
        angles = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
            1
        ) * inv_freq.unsqueeze(0)
        # The concatenation is already (num_embeddings, 2 * half_dim); no
        # reshape is needed (a `view(n, -1)` is ambiguous on empty tensors).
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    def forward(
        self,
        input,
        incremental_state: Optional[Any] = None,
        timestep: Optional[Tensor] = None,
        positions: Optional[Any] = None,
        chunks_num=None,
    ):
        """Look up positional embeddings for ``input``.

        Input is expected to be of size [bsz x seqlen].

        Args:
            input: token indices of size [bsz x seqlen].
            incremental_state: when not ``None``, a single position is
                returned for every batch element (single-step decoding).
            timestep: optional tensor whose first element, plus one, is the
                position used in incremental mode; otherwise ``seqlen`` is
                used.
            positions: accepted for interface compatibility; the value passed
                in is not used, positions are always recomputed.
            chunks_num: number of chunks each original sequence was split
                into. Required when ``args.global_positions`` is truthy; rows
                are regrouped by this factor before positions are computed.

        Returns:
            A detached tensor of size [bsz x 1 x embedding_dim] in incremental
            mode, otherwise [bsz x seqlen x embedding_dim].

        Raises:
            ValueError: if global positions are enabled but ``chunks_num`` is
                ``None``.
        """
        bspair = torch.onnx.operators.shape_as_tensor(input)
        bsz, seq_len = bspair[0], bspair[1]

        # Global positions only apply to the non-incremental path.
        use_global_positions = incremental_state is None and getattr(
            self.args, 'global_positions', 0
        )

        grouped = None
        longest = seq_len
        if use_global_positions:
            if chunks_num is None:
                raise ValueError(
                    "chunks_num is required when args.global_positions is enabled"
                )
            # use `chunks_num` calculated in unfold
            grouped = input.view(input.shape[0] // chunks_num, -1)
            # Positions run over the regrouped rows, so the table must cover
            # their full length rather than a single chunk's length.
            longest = torch.onnx.operators.shape_as_tensor(grouped)[1]

        max_pos = self.padding_idx + 1 + longest
        if self.weights is None or max_pos > self.weights.size(0):
            # recompute/expand embeddings if needed
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                max_pos, self.embedding_dim, self.padding_idx
            )
        self.weights = self.weights.to(self._float_tensor)

        if incremental_state is not None:
            # positions is the same for every token when decoding a single step
            pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
            if self.onnx_trace:
                return (
                    self.weights.index_select(index=self.padding_idx + pos, dim=0)
                    .unsqueeze(1)
                    .repeat(bsz, 1, 1)
                )
            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)

        if use_global_positions:
            positions = utils.make_positions(
                grouped, self.padding_idx, onnx_trace=self.onnx_trace
            )
            # Back to the chunked layout of `input`.
            positions = positions.view(input.shape[0], -1)
        else:
            positions = utils.make_positions(
                input, self.padding_idx, onnx_trace=self.onnx_trace
            )

        if self.onnx_trace:
            flat_embeddings = self.weights.detach().index_select(0, positions.view(-1))
            embedding_shape = torch.cat(
                (bsz.view(1), seq_len.view(1), torch.tensor([-1], dtype=torch.long))
            )
            embeddings = torch.onnx.operators.reshape_from_tensor_shape(
                flat_embeddings, embedding_shape
            )
            return embeddings
        return (
            self.weights.index_select(0, positions.view(-1))
            .view(bsz, seq_len, -1)
            .detach()
        )