Skip to content

Commit 007915b

Browse files
authored
Update b3d to GenJAX 0.6.1 (GEN-561) (#158)
This PR: - upgrades GenJAX to 0.6.1 - massages a few `ChoiceMap` construction calls to use nicer, newer methods - replaces broken `.const` calls on `Const` objects with `unwrap()`. `val` would work too, but this is the preferred way. See the release notes here: https://github.com/probcomp/genjax/releases/tag/v0.6.0 and here because I fumbled the release and forgot a commit: https://github.com/probcomp/genjax/releases/tag/v0.6.1
1 parent 01e917c commit 007915b

File tree

16 files changed

+721
-699
lines changed

16 files changed

+721
-699
lines changed

notebooks/aug1demos/slam_color_room.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1621,7 +1621,7 @@
16211621
"# convert_rgbd_to_color_space = lambda x: b3d.colors.rgbd_to_labd(x)\n",
16221622
"# convert_color_space_to_rgbd = lambda x: b3d.colors.labd_to_rgbd(x)\n",
16231623
"def intermediate_likelihood_func(observed_rgbd, latent_rgbd, likelihood_args):\n",
1624-
" k = likelihood_args[\"k\"].const\n",
1624+
" k = likelihood_args[\"k\"].unwrap()\n",
16251625
" fx = likelihood_args[\"fx\"]\n",
16261626
" fy = likelihood_args[\"fy\"]\n",
16271627
"\n",

notebooks/bayes3d_paper/interactive.ipynb

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@
169169
"# def sample_likelihood_func()\n",
170170
"\n",
171171
"def intermediate_likelihood_func(observed_rgbd, latent_rgbd, likelihood_args):\n",
172-
" k = likelihood_args[\"k\"].const\n",
172+
" k = likelihood_args[\"k\"].unwrap()\n",
173173
" fx = likelihood_args[\"fx\"]\n",
174174
" fy = likelihood_args[\"fy\"]\n",
175175
" \n",
@@ -628,16 +628,16 @@
628628
" rendered_rgbd = renderer.render_rgbd_from_mesh(meshes[IDX].transform(pose))\n",
629629
" rendered_color_space_d = convert_rgbd_to_color_space(rendered_rgbd)\n",
630630
"\n",
631-
" k = likelihood_args[\"k\"].const\n",
631+
" k = likelihood_args[\"k\"].unwrap()\n",
632632
" image_height, image_width = rendered_color_space_d.shape[0], rendered_color_space_d.shape[1]\n",
633633
" image_height = Pytree.const(image_height)\n",
634634
" image_width = Pytree.const(image_width)\n",
635635
"\n",
636636
" row_coordinates = genjax.categorical.vmap(in_axes=(0,))(\n",
637-
" jnp.ones((k, image_height.const))\n",
637+
" jnp.ones((k, image_height.unwrap()))\n",
638638
" ) @ \"row_coordinates\"\n",
639639
" column_coordinates = genjax.categorical.vmap(in_axes=(0,))(\n",
640-
" jnp.ones((k, image_width.const))\n",
640+
" jnp.ones((k, image_width.unwrap()))\n",
641641
" ) @ \"column_coordinates\"\n",
642642
"\n",
643643
"\n",

notebooks/bayes3d_paper/run_ycbv_evaluation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,11 @@ def _gvmf_and_select_best_move(
108108
Pose.sample_gaussian_vmf_pose, in_axes=(0, None, None, None)
109109
)(
110110
jax.random.split(key, number),
111-
trace.get_choices()[address.const],
111+
trace.get_choices()[address.unwrap()],
112112
variance,
113113
concentration,
114114
),
115-
trace.get_choices()[address.const][None, ...],
115+
trace.get_choices()[address.unwrap()][None, ...],
116116
]
117117
)
118118
scores = jax.vmap(update_pose_and_color, in_axes=(None, None, 0))(

notebooks/bayes3d_paper/ycbv.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@
346346
"@genjax.gen\n",
347347
"def dense_multiobject_model(num_objects, meshes, likelihood_args):\n",
348348
" all_poses = []\n",
349-
" for i in range(num_objects.const):\n",
349+
" for i in range(num_objects.unwrap()):\n",
350350
" object_pose = (\n",
351351
" uniform_pose(jnp.ones(3) * -100.0, jnp.ones(3) * 100.0) @ f\"object_pose_{i}\"\n",
352352
" )\n",

notebooks/integration.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,7 @@
534534
"metadata": {},
535535
"outputs": [],
536536
"source": [
537-
"scantr = ps.particle_system_state_step.scan(n=(num_timesteps.const - 1)).simulate(\n",
537+
"scantr = ps.particle_system_state_step.scan(n=(num_timesteps.unwrap() - 1)).simulate(\n",
538538
" key, (state0, None)\n",
539539
")"
540540
]

pixi.lock

Lines changed: 674 additions & 652 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ cwd = "scripts"
109109
[tool.pixi.feature.core.pypi-dependencies]
110110
carvekit = "==4.1.2"
111111
datasync = "==0.0.2"
112-
genjax = "==0.5.1"
112+
genjax = "==0.6.1"
113113
pykitti = "==0.3.1"
114114
pyliblzfse = { git = "https://github.com/ydkhatri/pyliblzfse.git" }
115115
pyransac3d = ">=0.6.0,<0.7"

src/b3d/chisight/dense/dense_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ def make_dense_multiobject_model(renderer, likelihood_func, sample_func=None):
1616
def f(key, likelihood_args):
1717
return jnp.zeros(
1818
(
19-
likelihood_args["image_height"].const,
20-
likelihood_args["image_width"].const,
19+
likelihood_args["image_height"].unwrap(),
20+
likelihood_args["image_width"].unwrap(),
2121
4,
2222
)
2323
)
@@ -44,7 +44,7 @@ def dense_multiobject_model(args_dict):
4444
likelihood_args["blur"] = blur
4545

4646
all_poses = []
47-
for i in range(num_objects.const):
47+
for i in range(num_objects.unwrap()):
4848
object_pose = (
4949
uniform_pose(jnp.ones(3) * -100.0, jnp.ones(3) * 100.0)
5050
@ f"object_pose_{i}"

src/b3d/chisight/dense/likelihoods/blur_likelihood_gaussian.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def likelihood_per_pixel(latent_rgbd: jnp.ndarray, blur):
114114

115115
@jax.jit
116116
def blur_intermediate_likelihood_func(observed_rgbd, latent_rgbd, likelihood_args):
117-
# k = likelihood_args["k"].const
117+
# k = likelihood_args["k"].unwrap()
118118
color_variance = likelihood_args["color_variance_0"]
119119
depth_variance = likelihood_args["depth_variance_0"]
120120
outlier_probability = likelihood_args["outlier_probability_0"]

src/b3d/chisight/dynamic_object_model/likelihoods/aggreate_mean_image_kernel.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def sample_func(key, args):
2424
).astype(jnp.int32)
2525

2626
latent_image_sum = jnp.zeros(
27-
(args["image_height"].const, args["image_width"].const, 4)
27+
(args["image_height"].unwrap(), args["image_width"].unwrap(), 4)
2828
)
2929
latent_image_sum = latent_image_sum.at[pixels[..., 0], pixels[..., 1], :3].add(
3030
args["colors"] * (1 - args["color_outlier_probability"])[:, None]
@@ -34,14 +34,14 @@ def sample_func(key, args):
3434
)
3535

3636
projected_points_count = jnp.zeros(
37-
(args["image_height"].const, args["image_width"].const)
37+
(args["image_height"].unwrap(), args["image_width"].unwrap())
3838
)
3939
projected_points_count = projected_points_count.at[
4040
pixels[..., 0], pixels[..., 1]
4141
].add(1)
4242

4343
non_registration_probability = jnp.ones(
44-
(args["image_height"].const, args["image_width"].const)
44+
(args["image_height"].unwrap(), args["image_width"].unwrap())
4545
)
4646
non_registration_probability = non_registration_probability.at[
4747
pixels[..., 0], pixels[..., 1]
@@ -81,7 +81,7 @@ def likelihood_func(observed_rgbd, args):
8181
).astype(jnp.int32)
8282

8383
latent_image_sum = jnp.zeros(
84-
(args["image_height"].const, args["image_width"].const, 4)
84+
(args["image_height"].unwrap(), args["image_width"].unwrap(), 4)
8585
)
8686
latent_image_sum = latent_image_sum.at[pixels[..., 0], pixels[..., 1], :3].add(
8787
args["colors"] * (1 - args["color_outlier_probability"])[:, None]
@@ -91,14 +91,14 @@ def likelihood_func(observed_rgbd, args):
9191
)
9292

9393
projected_points_count = jnp.zeros(
94-
(args["image_height"].const, args["image_width"].const)
94+
(args["image_height"].unwrap(), args["image_width"].unwrap())
9595
)
9696
projected_points_count = projected_points_count.at[
9797
pixels[..., 0], pixels[..., 1]
9898
].add(1)
9999

100100
non_registration_probability = jnp.ones(
101-
(args["image_height"].const, args["image_width"].const)
101+
(args["image_height"].unwrap(), args["image_width"].unwrap())
102102
)
103103
non_registration_probability = non_registration_probability.at[
104104
pixels[..., 0], pixels[..., 1]

0 commit comments

Comments
 (0)