-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathpositional.py
39 lines (21 loc) · 1.04 KB
/
positional.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import math
import torch
import torch.nn as nn
from einops import rearrange, repeat
def PositionalEncoder(image_shape, num_frequency_bands: int, max_frequencies=None) -> torch.Tensor:
    """Build sin/cos Fourier-feature position encodings for a regular grid.

    Every spatial axis is given coordinates evenly spaced in ``[-1, 1]``.
    For each axis, the coordinate is multiplied by ``num_frequency_bands``
    frequencies spaced linearly from ``1.0`` to ``max_freq / 2`` and passed
    through ``sin(pi * x)`` and ``cos(pi * x)``.

    Args:
        image_shape: Sequence ``(*spatial_shape, last)``. Only the leading
            spatial sizes are used; the final entry is ignored.
        num_frequency_bands: Number of frequencies generated per spatial axis.
        max_frequencies: Optional iterable with one maximum frequency per
            spatial axis. Defaults to the spatial sizes themselves.

    Returns:
        A tensor of shape
        ``(prod(spatial_shape), 2 * len(spatial_shape) * num_frequency_bands)``.
        Columns hold the sine features of axis 0, axis 1, ... followed by the
        cosine features in the same axis order. Rows enumerate grid positions
        with the last spatial axis varying fastest.

    Raises:
        ValueError: If ``max_frequencies`` does not provide exactly one value
            per spatial axis.
    """
    *spatial_shape, _ = image_shape
    num_axes = len(spatial_shape)

    # Coordinate grid of shape (*spatial_shape, num_axes). "ij" indexing is
    # spelled out so that grid axis k follows spatial axis k and PyTorch does
    # not warn about the implicit default.
    axis_coords = [torch.linspace(-1, 1, steps=size) for size in spatial_shape]
    pos = torch.stack(torch.meshgrid(*axis_coords, indexing="ij"), dim=num_axes)

    if max_frequencies is None:
        max_frequencies = tuple(pos.shape[:-1])
    else:
        max_frequencies = tuple(max_frequencies)
        # A shorter sequence would silently leave trailing axes unencoded and
        # a longer one fails later with an opaque broadcasting error.
        if len(max_frequencies) != num_axes:
            raise ValueError(
                f"max_frequencies must have one entry per spatial axis "
                f"(expected {num_axes}, got {len(max_frequencies)})"
            )

    # For each axis: (*spatial_shape, 1) * (1, bands) -> (*spatial_shape, bands).
    frequency_grids = [
        pos[..., axis:axis + 1]
        * torch.linspace(1.0, max_freq / 2.0, num_frequency_bands)[None, ...]
        for axis, max_freq in enumerate(max_frequencies)
    ]

    # Concatenating before applying sin/cos keeps the column order
    # [sin(axis 0), sin(axis 1), ..., cos(axis 0), cos(axis 1), ...].
    angles = math.pi * torch.cat(frequency_grids, dim=-1)
    enc = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    # Collapse every spatial axis into a single leading position axis
    # (same result as einops "... c -> (...) c").
    return enc.flatten(0, -2)