-
Notifications
You must be signed in to change notification settings - Fork 6
/
decomposition.py
132 lines (110 loc) · 4.55 KB
/
decomposition.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import os
import numpy as np
import image_util
# from numba import vectorize,jitclass,jit
class IntrinsicDecomposition(object):
    """ Current state of a reconstruction. All entries (except ``input``) are
    mutable. """

    def __init__(self, params, input):
        # ``input`` is only exposed read-only, through the ``input`` property.
        self._input = input
        self.params = params
        # iteration number
        self.iter_num = None
        # stage 1 or 2 (each iteration has 2 stages)
        self.stage_num = None
        # labels ("x" variable in the paper), where "_nz" indicates that only the
        # nonmasked entries are stored.
        self.labels_nz = None
        # reflectance intensity (obtained from kmeans)
        self.intensities = None
        # reflectance chromaticity (obtained from kmeans)
        self.chromaticities = None
        # store here for visualization only
        self.shading_target = None

    @staticmethod
    def _copy_or_none(arr):
        """ Return ``arr.copy()``, or ``None`` when ``arr`` is ``None``. """
        return None if arr is None else arr.copy()

    def copy(self):
        """ Return a copy of this state.

        ``params`` and ``input`` are shared with the copy. The array
        attributes are copied, so mutating the copy does not affect this
        object. Array attributes that are still ``None`` (as set by
        ``__init__``) stay ``None`` in the copy. """
        ret = IntrinsicDecomposition(self.params, self.input)
        ret.iter_num = self.iter_num
        ret.stage_num = self.stage_num
        ret.labels_nz = self._copy_or_none(self.labels_nz)
        ret.intensities = self._copy_or_none(self.intensities)
        ret.chromaticities = self._copy_or_none(self.chromaticities)
        ret.shading_target = self._copy_or_none(self.shading_target)
        return ret

    def get_r_s_nz(self):
        """ Return (reflectance, shading), with just the nonmasked entries.

        Shading is the gray image divided by the intensity of each pixel's
        label. Reflectance is the RGB image divided by that shading.
        Shading is clipped to [1e-4, 1e5] only for the reflectance
        division, to avoid dividing by ~0. """
        s_nz = self.input.image_gray_nz / self.intensities[self.labels_nz]
        r_nz = self.input.image_rgb_nz / np.clip(s_nz, 1e-4, 1e5)[:, np.newaxis]
        assert s_nz.ndim == 1 and r_nz.ndim == 2 and r_nz.shape[1] == 3
        return r_nz, s_nz

    def get_r_s(self):
        """ Return (reflectance, shading), in the full (rows, cols) shape.

        Masked-out pixels are zero in both outputs. """
        r_nz, s_nz = self.get_r_s_nz()
        r = np.zeros((self.input.rows, self.input.cols, 3), dtype=r_nz.dtype)
        s = np.zeros((self.input.rows, self.input.cols), dtype=s_nz.dtype)
        r[self.input.mask_nz] = r_nz
        s[self.input.mask_nz] = s_nz
        assert s.ndim == 2 and r.ndim == 3 and r.shape[2] == 3
        return r, s

    def get_r_gray(self):
        """ Return the per-pixel reflectance intensity as a (rows, cols)
        array, with masked-out pixels set to zero. """
        r_nz = self.intensities[self.labels_nz]
        r = np.zeros((self.input.rows, self.input.cols), dtype=r_nz.dtype)
        r[self.input.mask_nz] = r_nz
        return r

    def get_labels_visualization(self):
        """ Return a (rows, cols, 3) image coloring each pixel with the RGB
        reflectance of its label; masked-out pixels are black. """
        colors = self.get_reflectances_rgb()
        # Extra black row at index ``nlabels`` for masked (label -1) pixels.
        colors = np.vstack((colors, [0.0, 0.0, 0.0]))
        labels = self.get_labels()
        labels[labels == -1] = self.nlabels
        return colors[labels, :]

    def get_reflectances_rgb(self):
        """ Return an (nlabels, 3) float array with the RGB reflectance of
        each label.

        The blue chromaticity is taken as ``1 - r - g``. Each channel is
        ``3 * intensity * chromaticity``. """
        nlabels = self.intensities.shape[0]
        rgb = np.zeros((nlabels, 3))
        s = 3.0 * self.intensities
        r = self.chromaticities[:, 0]
        g = self.chromaticities[:, 1]
        b = 1.0 - r - g
        rgb[:, 0] = s * r
        rgb[:, 1] = s * g
        rgb[:, 2] = s * b
        return rgb

    @property
    def nlabels(self):
        """ Number of reflectance labels (rows of ``intensities``). """
        return self.intensities.shape[0]

    @property
    def input(self):
        """ The input passed to the constructor (read-only). """
        return self._input

    def get_labels(self):
        """ Returns labels, expanded to the full image shape, with masked
        entries having a label of -1 """
        labels = np.empty((self.input.rows, self.input.cols), dtype=np.int32)
        labels.fill(-1)
        labels[self.input.mask_nz] = self.labels_nz
        return labels

    def save(self, solver, out_dir, save_extra=False, id=None):
        """ Save results to a directory.

        Writes ``<out_dir>/<id>-r.png`` (reflectance) and
        ``<out_dir>/<id>-s.png`` (shading). With ``save_extra`` it also
        writes ``<id>-r-gray.png`` and ``<id>-labels.png``.

        :param solver: object whose ``input.mask_nz`` is passed as the mask
            to ``image_util.save``.
        :param out_dir: output directory; created if it does not exist.
        :param save_extra: also save the gray reflectance and label images.
        :param id: file-name prefix; falls back to ``self.input.id`` when
            falsy.
        :raises ValueError: if neither ``id`` nor ``self.input.id`` is set.
        """
        if not id:
            id = self.input.id
        if not id:
            raise ValueError("Need an id for saving")
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        # Build file names from the resolved ``id`` so that an explicitly
        # passed id takes precedence over ``self.input.id``.
        basename = os.path.join(out_dir, str(id))
        mask_nz = solver.input.mask_nz
        r, s = self.get_r_s()
        r_filename = '%s-r.png' % basename
        image_util.save(r_filename, r, mask_nz=mask_nz, rescale=True)
        s_filename = '%s-s.png' % basename
        image_util.save(s_filename, s, mask_nz=mask_nz, rescale=True)
        if save_extra:
            r_gray_filename = '%s-r-gray.png' % basename
            r_gray = self.get_r_gray()
            image_util.save(r_gray_filename, r_gray,
                            mask_nz=mask_nz, rescale=True)
            labels_filename = '%s-labels.png' % basename
            labels_image = self.get_labels_visualization()
            image_util.save(labels_filename, labels_image,
                            mask_nz=mask_nz, rescale=True)