-
Notifications
You must be signed in to change notification settings - Fork 3
/
ColTranColorUpsampler.py
60 lines (38 loc) · 1.46 KB
/
ColTranColorUpsampler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# -*- coding: utf-8 -*-
from utils import *
from GrayscaleEncoder import *
# Default compute device for this module: CUDA when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class ColTranColorUpsampler(nn.Module):
    """Color upsampler in the style of ColTran.

    For every color channel of a coarse color image, the grayscale
    embedding and that channel's embedding (each plus a 2D positional
    encoding) are summed, passed through a ``GrayscaleEncoder`` and
    projected to ``NColor`` logits per position.

    Args:
        D: embedding / model width.
        NColor: number of discrete intensity levels; size of every
            embedding vocabulary and of the output logits.
    """

    def __init__(self, D, NColor):
        super().__init__()
        self.D = D
        self.NColor = NColor
        # Embedding table for the grayscale intensities.
        self.embedding_x_g = nn.Embedding(NColor, D)
        # One embedding table per color channel, selected by index in
        # forward(). ModuleList is the container meant for indexed access
        # (nn.Sequential implies chained calls); it registers the children
        # under the same "0", "1", "2" names, so existing state_dicts load.
        self.embedding_x_rgb = nn.ModuleList(
            [nn.Embedding(NColor, D) for _ in range(3)]
        )
        self.grayscale_encoder = GrayscaleEncoder(D)
        # Projects each position to one logit per color level
        # (NColor levels, e.g. 256 color shades).
        self.linear = nn.Linear(D, NColor)

    def forward(self, x_g, x_s_c):
        """Predict per-channel color logits.

        Args:
            x_g: grayscale image as integer levels in ``[0, NColor)``.
                Its embedding is added to the positional encoding and to
                each channel embedding, so it must share their
                ``(batch, row, col)`` layout — presumably ``(batch, row, col)``;
                TODO confirm (an older note described it as ``M * N * 1``).
            x_s_c: coarse color image as integer levels in ``[0, NColor)``,
                indexed as ``(batch, row, col, channel)``. At most three
                channels are supported (one embedding table each).

        Returns:
            ``(pred, logits)``: ``logits`` has shape
            ``(batch, row, col, channel, NColor)`` and ``pred`` is its
            argmax over the last axis, shape ``(batch, row, col, channel)``.
            The spatial size is that of ``x_s_c``; no spatial upsampling
            is performed here.
        """
        batch, row, col, channel = x_s_c.shape
        pe = positionalencoding2d(self.D, row, col, batch)
        emb_g = pe + self.embedding_x_g(x_g)
        logits = []
        for k in range(channel):
            # NOTE(review): pe is already included in emb_g, so the encoder
            # input carries the positional encoding twice. Kept unchanged so
            # trained weights stay valid — confirm whether this is intended.
            emb_k = pe + self.embedding_x_rgb[k](x_s_c[:, :, :, k])
            out_encoder = self.grayscale_encoder(emb_g + emb_k)
            logits.append(self.linear(out_encoder))
        # Stack the per-channel logits instead of writing into a
        # torch.zeros buffer: the buffer was always allocated on the CPU in
        # float32, forcing device copies and returning CPU tensors even when
        # the model runs on GPU. Stacking keeps the model's device and dtype.
        out = torch.stack(logits, dim=3)
        return out.argmax(-1), out