'''
util.py
Last edited by: GunGyeom James Kim
Last edited at: Oct 24th, 2023
CS 7180: Advanced Perception
File containing utility functions

variable:
    cam2rgb - Global variable, transformation matrix for cam to RGB
function:
    read_16bit_png - read 16-bit png file using torch
    angularLoss - calculate accumulated angular loss in degrees
    illuminate - linearize, illuminate, map to RGB and gamma correct
    L2sRGB - map linear chromaticity space to sRGB chromaticity space
    to_rgb - map input to rgb chromaticity space
class:
    MaxResize - downscale input so its longer side is capped at a maximum length
    ContrastNormalization - apply histogram stretching to normalize contrast
    RandomPatches - randomly crop an image into a number of patch_size x patch_size patches
'''
# built-in
import math
from random import sample
# third-party
import numpy as np
# torch
import torch
from torchvision.io import read_file
from torchvision.transforms import functional as F

cam2rgb = np.array([
    1.8795, -1.0326, 0.1531,
    -0.2198, 1.7153, -0.4955,
    0.0069, -0.5150, 1.5081]).reshape((3, 3))
def read_16bit_png(file):
    '''
    Return 16-bit image
    Parameter:
        file(str or Path) - 16-bit image file
    Return:
        16-bit image as a CxHxW tensor
    '''
    data = read_file(file)
    # torchvision's internal PNG decoder: mode 0 keeps the channels as stored
    # and the final flag allows 16-bit output
    return torch.ops.image.decode_png(data, 0, True)
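
# A minimal usage sketch (illustrative, not part of the original module);
# "sample_raw.png" is a hypothetical 16-bit PNG path.
def _demo_read_16bit_png():
    img = read_16bit_png("sample_raw.png") # hypothetical file
    print(img.shape, img.dtype) # CxHxW tensor of decoded 16-bit values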
def angularLoss(xs, ys, singleton=False):
    '''
    Return accumulated angular loss in degrees
    Parameter:
        xs(tensor or sequence of tensors) - first operand(s) of the angular loss
        ys(tensor or sequence of tensors) - second operand(s) of the angular loss
        singleton(bool, optional) - if True, treat xs and ys as a single pair and return its angular error
    Return:
        output(float) - accumulated angular loss in degrees
    '''
    if singleton:
        if torch.count_nonzero(xs[0]).item() == 0: return 180
        return torch.rad2deg(torch.arccos(torch.nn.functional.cosine_similarity(xs, ys, dim=-1))).item()
    output = 0
    for x, y in zip(xs, ys):
        if torch.count_nonzero(x).item() == 0: output += 180 # an all-zero estimate counts as maximally wrong
        else: output += torch.rad2deg(torch.arccos(torch.nn.functional.cosine_similarity(x, y, dim=0))).item()
    return output
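
# A minimal sanity check for angularLoss (illustrative, not part of the
# original module): an orthogonal pair contributes 90 degrees and an
# identical pair contributes 0, so the accumulated loss is ~90.
def _demo_angular_loss():
    xs = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
    ys = torch.tensor([[0.0, 1.0, 0.0], [0.0, 1.0, 0.0]])
    print(angularLoss(xs, ys)) # ~90.0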
def illuminate(img, illum):
    '''
    Linearize, illuminate, map to RGB and gamma correct
    Parameter:
        img(tensor) - image to process, (c, h, w)
        illum(tensor) - illuminant to divide out, one value per channel
    Return:
        output(numpy.ndarray) - gamma-corrected sRGB image as uint8
    '''
    linearize = ContrastNormalization()
    linearized_img = linearize(img).permute(1,2,0).cpu().numpy() # c,h,w -> h,w,c
    illum = illum.cpu().numpy()
    white_balanced_image = linearized_img/illum
    rgb_img = np.dot(white_balanced_image, cam2rgb.T)
    rgb_img = np.clip(rgb_img, 0, 1)**(1/2.2)
    return (rgb_img*255).astype(np.uint8)
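
# A minimal usage sketch for illuminate (illustrative): the raw frame is a
# synthetic stand-in for camera data above the 2048 black level, and the
# illuminant is an assumed per-channel estimate.
def _demo_illuminate():
    raw = torch.randint(2048, 65535, (3, 8, 8)).float() # fake raw frame
    illum = torch.tensor([0.4, 0.4, 0.2]) # assumed illuminant estimate
    out = illuminate(raw, illum)
    print(out.shape, out.dtype) # (8, 8, 3) uint8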
def L2sRGB(linImg):
    '''
    Map linear chromaticity space to sRGB chromaticity space
    Parameter:
        linImg(tensor) - image in linear chromaticity space
    Return:
        linImg(tensor) - image mapped to sRGB chromaticity space (modified in place)
    '''
    low_mask = linImg <= 0.0031308
    high_mask = linImg > 0.0031308
    linImg[low_mask] *= 12.92
    linImg[high_mask] = 1.055 * linImg[high_mask]**(1/2.4) - 0.055
    return linImg
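
# A quick check for L2sRGB (illustrative). The function modifies its input
# in place, hence the clone().
def _demo_l2srgb():
    lin = torch.tensor([0.0, 0.001, 0.5, 1.0])
    print(L2sRGB(lin.clone())) # small values scaled by 12.92, the rest gamma-compressed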
def to_rgb(inputs):
    '''
    Map input to rgb chromaticity space (r,g,b in [0,1] such that r+g+b = 1)
    Parameter:
        inputs(tensor) - num_patches x 3, input in arbitrary chromaticity space
    Return:
        inputs(tensor) - input mapped to rgb chromaticity space (a single 1-D tensor when num_patches is 1)
    '''
    num_patches = inputs.shape[0]
    if num_patches == 1: return inputs[0] / torch.sum(inputs[0])
    for idx in range(num_patches):
        inputs[idx] = inputs[idx].clone() / torch.sum(inputs[idx])
    return inputs
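
# A quick check for to_rgb (illustrative): every output row sums to 1.
def _demo_to_rgb():
    patches = torch.tensor([[2.0, 1.0, 1.0], [1.0, 1.0, 2.0]])
    print(to_rgb(patches)) # rows become (0.50, 0.25, 0.25) and (0.25, 0.25, 0.50)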
#################
### Transform ###
#################
class MaxResize:
    '''
    Downscale input so its longer side is capped at self.max_length
    '''
    def __init__(self, max_length):
        '''
        Constructor
        Parameters:
            max_length(int) - maximum allowed length of the longer side
        '''
        self.max_length = max_length

    def __call__(self, img):
        '''
        Return downscaled image whose longer side is capped at self.max_length
        Parameter:
            img(tensor) - image to downscale
        Return:
            downscaled image whose longer side is capped at self.max_length
        '''
        _, h, w = img.size()
        ratio = float(h) / float(w)
        if ratio > 1: # h > w
            w0 = math.ceil(self.max_length / ratio)
            return F.resize(img, (self.max_length, w0), antialias=True)
        else: # h <= w
            h0 = math.ceil(self.max_length * ratio) # ratio <= 1, so the height shrinks along with the width
            return F.resize(img, (h0, self.max_length), antialias=True)
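
# A minimal usage sketch for MaxResize (illustrative): a 480x640 image is
# downscaled so its longer side becomes 256 while the aspect ratio is kept.
def _demo_max_resize():
    resize = MaxResize(256)
    img = torch.rand(3, 480, 640)
    print(resize(img).shape) # torch.Size([3, 192, 256])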
class ContrastNormalization:
    '''
    Apply Global Histogram Stretching to normalize contrast
    '''
    def __init__(self, black_lvl=2048):
        '''
        Constructor
        Parameter:
            black_lvl(int, optional) - black level originally captured by the camera
        '''
        self.black_lvl = black_lvl

    def __call__(self, img):
        '''
        Return contrast normalized image
        Parameters:
            img(tensor) - image to contrast normalize
        Return:
            output(tensor) - contrast normalized image in [0,1]
        '''
        saturation_lvl = torch.max(img)
        return torch.clamp((img - self.black_lvl)/(saturation_lvl - self.black_lvl), 0, 1)
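
# A quick check for ContrastNormalization (illustrative): the black level
# maps to 0 and the brightest value maps to 1.
def _demo_contrast_normalization():
    normalize = ContrastNormalization(black_lvl=2048)
    raw = torch.tensor([[2048.0, 4096.0, 8192.0]])
    print(normalize(raw)) # tensor([[0.0000, 0.3333, 1.0000]])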
class RandomPatches:
    '''
    Randomly crop an image into a number of patch_size x patch_size patches
    '''
    def __init__(self, patch_size, num_patches):
        '''
        Constructor
        Parameters:
            patch_size(int) - size of each patch
            num_patches(int) - number of patches to return
        '''
        self.patch_size = patch_size
        self.num_patches = num_patches

    def __call__(self, img):
        '''
        Return a stack of patch_size x patch_size patches randomly cropped from the image
        Parameter:
            img(tensor) - image to randomly crop patches from
        Return:
            stack of patch_size x patch_size patches randomly cropped from the image
        '''
        # specific to SimpleCube++: a calibration object in the bottom-right corner is masked out
        MASK_HEIGHT = 250
        MASK_WIDTH = 175
        # assign and initiate variables
        _, h, w = img.size()
        diameter = self.patch_size
        radius = self.patch_size // 2
        coords = set()
        center = list()
        # populate candidates for patch centers, keeping only those whose patch stays clear of the masked rectangle
        for row in range(radius, h-radius):
            for col in range(radius, w-radius):
                if row < h-radius-MASK_HEIGHT or col < w-radius-MASK_WIDTH: coords.add((row, col))
        # sample centers for patches
        for _ in range(self.num_patches):
            valid = False
            while coords and not valid:
                y0, x0 = sample(tuple(coords), 1)[0] # random.sample no longer accepts sets (Python 3.11+)
                coords.remove((y0, x0))
                valid = True
                for y, x in center:
                    if not valid: break # if it overlaps, try another one
                    valid &= (abs(y-y0) >= diameter or abs(x-x0) >= diameter) # patches are disjoint if separated along at least one axis
            if valid: center.append((y0, x0)) # if it doesn't overlap, keep it
        # crop patches around the chosen centers
        patches = []
        for y, x in center:
            patch = img[:, y-radius:y+radius, x-radius:x+radius].type(torch.float32) # use radius instead of a hardcoded 16
            patches.append(patch)
        return torch.stack(patches, dim=0) # list of tensors -> single stacked tensor
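
# A minimal usage sketch for RandomPatches (illustrative): the frame size is
# an assumption loosely matching SimpleCube++; fewer than num_patches patches
# may come back if the sampler runs out of non-overlapping candidates.
def _demo_random_patches():
    crop = RandomPatches(patch_size=32, num_patches=4)
    img = torch.rand(3, 432, 648)
    print(crop(img).shape) # up to torch.Size([4, 3, 32, 32])

if __name__ == "__main__":
    # _demo_read_16bit_png() is skipped here because it needs a real file on disk
    _demo_angular_loss()
    _demo_illuminate()
    _demo_l2srgb()
    _demo_to_rgb()
    _demo_max_resize()
    _demo_contrast_normalization()
    _demo_random_patches()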