-
Notifications
You must be signed in to change notification settings - Fork 3
/
ColTranSpatialUpsampler.py
63 lines (40 loc) · 1.54 KB
/
ColTranSpatialUpsampler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# -*- coding: utf-8 -*-
from utils import *
from GrayscaleEncoder import *
class ColTranSpatialUpsampler(nn.Module):
    """Spatial upsampler that predicts full-resolution colour logits.

    A low-resolution colour image is bilinearly resized to ``(H, W)``,
    embedded channel by channel, summed with an embedding of the grayscale
    image, passed through a ``GrayscaleEncoder`` and projected to ``NColor``
    logits per pixel and per channel.

    Args:
        D: Embedding / hidden width shared by every sub-module.
        NColor: Number of discrete intensity levels; size of each embedding
            table and of the output logit vector.
        H: Target (high-resolution) height.
        W: Target (high-resolution) width.
    """

    def __init__(self, D, NColor, H, W):
        super().__init__()
        self.D = D
        self.H = H
        self.W = W
        self.NColor = NColor
        # Grayscale intensities -> D-dimensional vectors.
        self.embedding_x_g = nn.Embedding(NColor, D)
        # One independent embedding table per colour channel.  They are only
        # ever indexed (never chained), so ModuleList is the right container;
        # state-dict keys stay "embedding_x_rgb.<k>.weight".
        self.embedding_x_rgb = nn.ModuleList(
            [nn.Embedding(NColor, D) for _ in range(3)]
        )
        self.grayscale_encoder = GrayscaleEncoder(D)
        # Projects encoder features to one logit per colour level.
        self.linear = nn.Linear(D, NColor)

    def forward(self, x_g, x_s):
        """Colourise ``x_g`` using the low-resolution colour image ``x_s``.

        Args:
            x_g: Integer tensor of grayscale levels fed to ``nn.Embedding``,
                so values must lie in ``[0, NColor)``.  Presumably shaped
                ``(batch, H, W)`` -- TODO confirm against callers.
            x_s: Low-resolution colour image in channels-last layout
                ``(batch, M, N, channels)``.  At most 3 channels are
                supported (one embedding table each).

        Returns:
            A tuple ``(x, logits)`` where ``logits`` has shape
            ``(batch, H, W, channels, NColor)`` and ``x = logits.argmax(-1)``
            holds the predicted colour level for every pixel and channel.
        """
        # interpolate expects channels-first floats; go back to channels-last
        # afterwards.  .long() truncates the interpolated values so they can
        # index the embedding tables.
        x_s = (
            F.interpolate(
                x_s.permute(0, 3, 1, 2).float(),
                size=(self.H, self.W),
                mode="bilinear",
            )
            .permute(0, 2, 3, 1)
            .long()
        )
        batch, row, col, channel = x_s.shape

        # NOTE(review): positionalencoding2d comes from utils; it is assumed
        # to return a tensor broadcastable with (batch, row, col, D) on the
        # same device as the embeddings -- confirm.
        pe = positionalencoding2d(self.D, row, col, batch)
        emb_g = pe + self.embedding_x_g(x_g)

        # Collect per-channel logits and stack them instead of writing into a
        # torch.zeros buffer: the stacked tensor inherits the device and dtype
        # of the model output, whereas a pre-allocated buffer would live on
        # the default (CPU, float32) device.
        # Assumes the encoder keeps the leading (batch, row, col) dims so each
        # entry is (batch, row, col, NColor).
        channel_logits = []
        for k in range(channel):
            emb_k = pe + self.embedding_x_rgb[k](x_s[:, :, :, k])
            out_encoder = self.grayscale_encoder(emb_g + emb_k)
            channel_logits.append(self.linear(out_encoder))

        out = torch.stack(channel_logits, dim=3)
        return out.argmax(-1), out