Skip to content

Commit 01e917c

Browse files
authored
Pixel Kernels (3/2): RGBD Kernels (#152)
While working on the image kernels, I realized that it might be useful to have a unified class that samples the RGBD values jointly, so I'm introducing this utility class that takes in a color kernel (from #147) and a depth kernel (from #149) to create an RGBD kernel, which supports sample and logpdf computation. This is a general class that's agnostic of the type of RGB and depth kernels -- it'll pass the arguments it receives to _both_ of the kernels during `sample` and `logpdf` calls. Since both color and depth kernels have `*args, **kwargs` as part of their function signatures, they should simply ignore additional arguments that are not relevant to them. (I also just realized that I was referring to these classes as "kernels" but the actual class names are `*Distributions`... maybe I should fix these in a future PR.) I'm submitting this PR to `main` for now. Once I've wrapped up my local changes, I'm going to work on resolving merge conflicts with `gen3d`. ## Test Plan Similar to the previous PRs, I've added some unit tests to make sure that the kernels roughly have the behaviors that we expect: ```bash pytest tests/gen3d/test_pixel_rgbd_kernels.py ```
1 parent 09b3c8c commit 01e917c

File tree

6 files changed

+174
-45
lines changed

6 files changed

+174
-45
lines changed

src/b3d/chisight/gen3d/pixel_kernels/pixel_color_kernels.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -129,15 +129,15 @@ def logpdf_per_channel(
129129
class MixturePixelColorDistribution(PixelColorDistribution):
130130
"""A distribution that generates the color of a pixel from a mixture of a
131131
truncated Laplace distribution centered around the latent color (inlier
132-
branch) and a uniform distribution (outlier branch). The mixture is
133-
controlled by the outlier_prob parameter. The support of the
132+
branch) and a uniform distribution (occluded branch). The mixture is
133+
controlled by the occluded_prob parameter. The support of the
134134
distribution is ([0, 1]^3).
135135
"""
136136

137137
color_scale: float
138138

139139
@property
140-
def _outlier_dist(self) -> PixelColorDistribution:
140+
def _occluded_dist(self) -> PixelColorDistribution:
141141
return UniformPixelColorDistribution()
142142

143143
@property
@@ -146,35 +146,35 @@ def _inlier_dist(self) -> PixelColorDistribution:
146146

147147
@property
148148
def _mixture_dists(self) -> tuple[PixelColorDistribution, PixelColorDistribution]:
149-
return (self._outlier_dist, self._inlier_dist)
149+
return (self._occluded_dist, self._inlier_dist)
150150

151-
def _get_mix_ratio(self, outlier_prob: float) -> FloatArray:
152-
return jnp.array((outlier_prob, 1 - outlier_prob))
151+
def _get_mix_ratio(self, occluded_prob: float) -> FloatArray:
152+
return jnp.array((occluded_prob, 1 - occluded_prob))
153153

154154
def sample(
155155
self,
156156
key: PRNGKey,
157157
latent_color: FloatArray,
158-
outlier_prob: float,
158+
occluded_prob: float,
159159
*args,
160160
**kwargs,
161161
) -> FloatArray:
162162
return PythonMixtureDistribution(self._mixture_dists).sample(
163-
key, self._get_mix_ratio(outlier_prob), [(), (latent_color,)]
163+
key, self._get_mix_ratio(occluded_prob), [(), (latent_color,)]
164164
)
165165

166166
def logpdf_per_channel(
167167
self,
168168
observed_color: FloatArray,
169169
latent_color: FloatArray,
170-
outlier_prob: float,
170+
occluded_prob: float,
171171
*args,
172172
**kwargs,
173173
) -> FloatArray:
174174
# Since the mixture model class does not keep the per-channel information,
175175
# we have to redefine this method to allow for testing
176176
logprobs = []
177-
for dist, prob in zip(self._mixture_dists, self._get_mix_ratio(outlier_prob)):
177+
for dist, prob in zip(self._mixture_dists, self._get_mix_ratio(occluded_prob)):
178178
logprobs.append(
179179
dist.logpdf_per_channel(observed_color, latent_color) + jnp.log(prob)
180180
)
@@ -192,7 +192,7 @@ class FullPixelColorDistribution(PixelColorDistribution):
192192
else:
193193
color ~ mixture(
194194
[uniform(0, 1), truncated_laplace(latent_color; color_scale)],
195-
[outlier_prob, 1 - outlier_prob]
195+
[occluded_prob, 1 - occluded_prob]
196196
)
197197
198198
Constructor args:
@@ -203,9 +203,9 @@ class FullPixelColorDistribution(PixelColorDistribution):
203203
- `latent_color`: 3-array. If no latent point hits the pixel, should contain
204204
3 negative values. If a latent point hits the pixel, should contain the point's
205205
color as an RGB value in [0, 1]^3.
206-
- `color_outlier_prob`: float. If a latent point hits the pixel, should contain
206+
- `color_occluded_prob`: float. If a latent point hits the pixel, should contain
207207
the probability associated with that point that the generated color is
208-
an outlier. If no latent point hits the pixel, this value is ignored.
208+
occluded. If no latent point hits the pixel, this value is ignored.
209209
210210
Distribution support:
211211
- An RGB value in [0, 1]^3.
@@ -225,7 +225,7 @@ def sample(
225225
self,
226226
key: PRNGKey,
227227
latent_color: FloatArray,
228-
outlier_prob: FloatArray,
228+
occluded_prob: FloatArray,
229229
*args,
230230
**kwargs,
231231
) -> FloatArray:
@@ -236,14 +236,14 @@ def sample(
236236
# sample args
237237
key,
238238
latent_color,
239-
outlier_prob,
239+
occluded_prob,
240240
)
241241

242242
def logpdf_per_channel(
243243
self,
244244
observed_color: FloatArray,
245245
latent_color: FloatArray,
246-
outlier_prob: float,
246+
occluded_prob: float,
247247
*args,
248248
**kwargs,
249249
) -> FloatArray:
@@ -254,5 +254,5 @@ def logpdf_per_channel(
254254
# logpdf args
255255
observed_color,
256256
latent_color,
257-
outlier_prob,
257+
occluded_prob,
258258
)

src/b3d/chisight/gen3d/pixel_kernels/pixel_depth_kernels.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ class MixturePixelDepthDistribution(PixelDepthDistribution):
114114
"""A distribution that generates the depth of a pixel from
115115
mixture(
116116
[delta(-1), uniform(near, far), laplace(latent_depth; depth_scale)],
117-
[depth_nonreturn_prob, (1 - depth_nonreturn_prob) * outlier_prob, remaining_prob]
117+
[depth_nonreturn_prob, (1 - depth_nonreturn_prob) * occluded_prob, remaining_prob]
118118
)
119119
120120
The support of the distribution is [near, far] ∪ { "nonreturn" }.
@@ -129,7 +129,7 @@ def _nonreturn_dist(self) -> PixelDepthDistribution:
129129
return DeltaDistribution(DEPTH_NONRETURN_VAL)
130130

131131
@property
132-
def _outlier_dist(self) -> PixelDepthDistribution:
132+
def _occluded_dist(self) -> PixelDepthDistribution:
133133
return UniformPixelDepthDistribution(self.near, self.far)
134134

135135
@property
@@ -141,47 +141,47 @@ def _inlier_dist(self) -> PixelDepthDistribution:
141141
@property
142142
def _mixture_dist(self) -> PythonMixtureDistribution:
143143
return PythonMixtureDistribution(
144-
(self._nonreturn_dist, self._outlier_dist, self._inlier_dist)
144+
(self._nonreturn_dist, self._occluded_dist, self._inlier_dist)
145145
)
146146

147147
def _get_mix_ratio(
148-
self, depth_nonreturn_prob: float, outlier_prob: float
148+
self, occluded_prob: float, depth_nonreturn_prob: float
149149
) -> FloatArray:
150150
return jnp.array(
151151
(
152152
depth_nonreturn_prob,
153-
(1 - depth_nonreturn_prob) * outlier_prob,
154-
(1 - depth_nonreturn_prob) * (1 - outlier_prob),
153+
(1 - depth_nonreturn_prob) * occluded_prob,
154+
(1 - depth_nonreturn_prob) * (1 - occluded_prob),
155155
)
156156
)
157157

158158
def sample(
159159
self,
160160
key: PRNGKey,
161161
latent_depth: float,
162+
occluded_prob: float,
162163
depth_nonreturn_prob: float,
163-
outlier_prob: float,
164164
*args,
165165
**kwargs,
166166
) -> float:
167167
return self._mixture_dist.sample(
168168
key,
169-
self._get_mix_ratio(depth_nonreturn_prob, outlier_prob),
169+
self._get_mix_ratio(occluded_prob, depth_nonreturn_prob),
170170
[(), (), (latent_depth,)],
171171
)
172172

173173
def logpdf(
174174
self,
175175
observed_depth: float,
176176
latent_depth: float,
177+
occluded_prob: float,
177178
depth_nonreturn_prob: float,
178-
outlier_prob: float,
179179
*args,
180180
**kwargs,
181181
) -> float:
182182
return self._mixture_dist.logpdf(
183183
observed_depth,
184-
self._get_mix_ratio(depth_nonreturn_prob, outlier_prob),
184+
self._get_mix_ratio(occluded_prob, depth_nonreturn_prob),
185185
[(), (), (latent_depth,)],
186186
)
187187

@@ -255,7 +255,7 @@ class FullPixelDepthDistribution(PixelDepthDistribution):
255255
else:
256256
mixture(
257257
[delta(-1), uniform(near, far), laplace(latent_depth; depth_scale)],
258-
[depth_nonreturn_prob, (1 - depth_nonreturn_prob) * outlier_prob, remaining_prob]
258+
[depth_nonreturn_prob, (1 - depth_nonreturn_prob) * occluded_prob, remaining_prob]
259259
)
260260
"""
261261

@@ -275,8 +275,8 @@ def sample(
275275
self,
276276
key: PRNGKey,
277277
latent_depth: FloatArray,
278+
occluded_prob: FloatArray,
278279
depth_nonreturn_prob: float,
279-
outlier_prob: FloatArray,
280280
*args,
281281
**kwargs,
282282
) -> FloatArray:
@@ -287,16 +287,16 @@ def sample(
287287
# sample args
288288
key,
289289
latent_depth,
290+
occluded_prob,
290291
depth_nonreturn_prob,
291-
outlier_prob,
292292
)
293293

294294
def logpdf(
295295
self,
296296
observed_depth: FloatArray,
297297
latent_depth: FloatArray,
298+
occluded_prob: float,
298299
depth_nonreturn_prob: float,
299-
outlier_prob: float,
300300
*args,
301301
**kwargs,
302302
) -> FloatArray:
@@ -307,6 +307,6 @@ def logpdf(
307307
# logpdf args
308308
observed_depth,
309309
latent_depth,
310+
occluded_prob,
310311
depth_nonreturn_prob,
311-
outlier_prob,
312312
)
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import genjax
2+
import jax
3+
import jax.numpy as jnp
4+
from genjax import Pytree
5+
from genjax.typing import FloatArray, PRNGKey
6+
7+
from b3d.chisight.gen3d.pixel_kernels.pixel_color_kernels import PixelColorDistribution
8+
from b3d.chisight.gen3d.pixel_kernels.pixel_depth_kernels import PixelDepthDistribution
9+
10+
11+
@Pytree.dataclass
class PixelRGBDDistribution(genjax.ExactDensity):
    """Joint distribution over a pixel's RGBD value, built from two kernels.

    The first three entries of an RGBD array are handled by `color_kernel`
    and the fourth entry by `depth_kernel`. Every extra positional or keyword
    argument is forwarded unchanged to *both* kernels.
    """

    color_kernel: PixelColorDistribution
    depth_kernel: PixelDepthDistribution

    def sample(
        self, key: PRNGKey, latent_rgbd: FloatArray, *args, **kwargs
    ) -> FloatArray:
        """Sample an observed RGBD value given the latent RGBD value.

        Args:
            key: PRNG key; it is split so that each kernel gets its own subkey.
            latent_rgbd: array whose entries [:3] are passed to the color
                kernel and whose entry [3] is passed to the depth kernel.
            *args, **kwargs: forwarded as-is to both kernels' `sample`.

        Returns:
            The sampled color with the sampled depth appended to it.
        """
        color_key, depth_key = jax.random.split(key, 2)
        latent_color = latent_rgbd[:3]
        latent_depth = latent_rgbd[3]

        sampled_color = self.color_kernel.sample(
            color_key, latent_color, *args, **kwargs
        )
        sampled_depth = self.depth_kernel.sample(
            depth_key, latent_depth, *args, **kwargs
        )
        return jnp.append(sampled_color, sampled_depth)

    def logpdf(
        self, observed_rgbd: FloatArray, latent_rgbd: FloatArray, *args, **kwargs
    ) -> float:
        """Return the log density of an observed RGBD value.

        Args:
            observed_rgbd: observed array; entries [:3] are scored by the
                color kernel and entry [3] by the depth kernel.
            latent_rgbd: latent array, split the same way as `observed_rgbd`.
            *args, **kwargs: forwarded as-is to both kernels' `logpdf`.

        Returns:
            The sum of the color kernel's and the depth kernel's log densities.
        """
        observed_color, observed_depth = observed_rgbd[:3], observed_rgbd[3]
        latent_color, latent_depth = latent_rgbd[:3], latent_rgbd[3]

        color_term = self.color_kernel.logpdf(
            observed_color, latent_color, *args, **kwargs
        )
        depth_term = self.depth_kernel.logpdf(
            observed_depth, latent_depth, *args, **kwargs
        )
        return color_term + depth_term

tests/gen3d/test_pixel_color_kernels.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ def generate_color_grid(n_grid_steps: int):
3636
sample_kernels_to_test = [
3737
(UniformPixelColorDistribution(), ()),
3838
(TruncatedLaplacePixelColorDistribution(0.1), ()),
39-
(MixturePixelColorDistribution(0.3), (0.5,)), # outlier_prob
40-
(FullPixelColorDistribution(0.5), (0.3,)), # outlier_prob
39+
(MixturePixelColorDistribution(0.3), (0.5,)), # occluded_prob
40+
(FullPixelColorDistribution(0.5), (0.3,)), # occluded_prob
4141
]
4242

4343

@@ -80,15 +80,15 @@ def test_relative_logpdf():
8080
latent_color = -jnp.ones(3) # use -1 to denote invalid pixel
8181
logpdf_1 = kernel.logpdf(obs_color, latent_color, 0.2)
8282
logpdf_2 = kernel.logpdf(obs_color, latent_color, 0.8)
83-
# the logpdf should be the same because the outlier probability is not used
83+
# the logpdf should be the same because the occluded probability is not used
8484
# in the case when no color hit the pixel
8585
assert jnp.allclose(logpdf_1, logpdf_2)
8686

8787
# case 2: a color hit the pixel, but the color is not close to the observed color
8888
latent_color = jnp.array([1.0, 0.5, 0.0])
8989
logpdf_3 = kernel.logpdf(obs_color, latent_color, 0.2)
9090
logpdf_4 = kernel.logpdf(obs_color, latent_color, 0.8)
91-
# the pixel should be more likely to be an outlier
91+
# the pixel should be more likely to be occluded
9292
assert logpdf_3 < logpdf_4
9393

9494
# case 3: a color hit the pixel, and the color is close to the observed color

tests/gen3d/test_pixel_depth_kernels.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,15 @@
2222
(
2323
MixturePixelDepthDistribution(near, far, 0.15),
2424
(
25+
0.5, # occluded_prob
2526
0.23, # depth_nonreturn_prob
26-
0.5, # outlier_prob
2727
),
2828
),
2929
(
3030
FullPixelDepthDistribution(near, far, 0.5),
3131
(
32+
0.3, # occluded_prob
3233
0.1, # depth_nonreturn_prob
33-
0.3, # outlier_prob
3434
),
3535
),
3636
]
@@ -81,26 +81,26 @@ def test_relative_logpdf():
8181
obs_depth = DEPTH_NONRETURN_VAL
8282
latent_depth = DEPTH_NONRETURN_VAL
8383
depth_nonreturn_prob = 0.2
84-
logpdf_1 = kernel.logpdf(obs_depth, latent_depth, depth_nonreturn_prob, 0.8)
84+
logpdf_1 = kernel.logpdf(obs_depth, latent_depth, 0.8, depth_nonreturn_prob)
8585
assert logpdf_1 == jnp.log(depth_nonreturn_prob)
8686

8787
latent_depth = -1.0 # no depth information from latent
88-
logpdf_2 = kernel.logpdf(obs_depth, latent_depth, depth_nonreturn_prob, 0.8)
88+
logpdf_2 = kernel.logpdf(obs_depth, latent_depth, 0.8, depth_nonreturn_prob)
8989
# nonreturn obs cannot be generated from latent that is not nonreturn
9090
assert logpdf_2 == jnp.log(UNEXPLAINED_DEPTH_NONRETURN_PROB)
9191

9292
# case 2: valid depth is observed, but latent depth is far from the observed depth
9393
obs_depth = 10.0
9494
latent_depth = 0.01
95-
logpdf_3 = kernel.logpdf(obs_depth, latent_depth, depth_nonreturn_prob, 0.9)
96-
logpdf_4 = kernel.logpdf(obs_depth, latent_depth, depth_nonreturn_prob, 0.1)
97-
# the pixel should be more likely to be an outlier
95+
logpdf_3 = kernel.logpdf(obs_depth, latent_depth, 0.9, depth_nonreturn_prob)
96+
logpdf_4 = kernel.logpdf(obs_depth, latent_depth, 0.1, depth_nonreturn_prob)
97+
# the pixel should be more likely to be occluded
9898
assert logpdf_3 > logpdf_4
9999

100100
# case 3: valid depth is observed, and latent depth is close to the observed depth
101101
obs_depth = 6.0
102102
latent_depth = 6.01
103-
logpdf_5 = kernel.logpdf(obs_depth, latent_depth, depth_nonreturn_prob, 0.9)
104-
logpdf_6 = kernel.logpdf(obs_depth, latent_depth, depth_nonreturn_prob, 0.1)
103+
logpdf_5 = kernel.logpdf(obs_depth, latent_depth, 0.9, depth_nonreturn_prob)
104+
logpdf_6 = kernel.logpdf(obs_depth, latent_depth, 0.1, depth_nonreturn_prob)
105105
# the pixel should be more likely to be an inlier
106106
assert logpdf_5 < logpdf_6

0 commit comments

Comments
 (0)