Skip to content

Commit

Permalink
feat(overcooked): Add highlight mask to image
Browse files Browse the repository at this point in the history
  • Loading branch information
tobiges committed Jun 30, 2024
1 parent 59e4a96 commit 4a78659
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 22 deletions.
7 changes: 3 additions & 4 deletions jaxmarl/viz/grid_rendering_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,6 @@ def highlight_img(img, color=(255, 255, 255), alpha=0.30):
"""
Add highlighting to an image
"""
blend_img = img + alpha * (jnp.array(color, dtype=jnp.uint8) - img)
blend_img = jnp.clip(blend_img, 0, 255).astype(jnp.uint8)
img = img.at[:, :, :].set(blend_img)
return img
blend_img = alpha * (jnp.array(color, dtype=jnp.uint8) - img)
res_img = jnp.clip(img + blend_img, 0, 255).astype(jnp.uint8)
return res_img
38 changes: 20 additions & 18 deletions jaxmarl/viz/overcooked_v2_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,19 +114,24 @@ def _include_agents(grid, agent):
grid = grid.at[:, :, 1].set(new_ingredients_layer)

highlight_mask = jnp.zeros(grid.shape[:2], dtype=bool)
# if self.agent_view_size:
# for x, y in zip(agents.pos.x, agents.pos.y):
# x_low, x_high, y_low, y_high = compute_view_box(
# x, y, self.agent_view_size, grid.shape[0], grid.shape[1]
# )
# # highlight_mask[y_low:y_high, x_low:x_high] = True
# highlight_mask = highlight_mask.at[y_low:y_high, x_low:x_high].set(True)
if agent_view_size:
for x, y in zip(agents.pos.x, agents.pos.y):
x_low, x_high, y_low, y_high = compute_view_box(
x, y, agent_view_size, grid.shape[0], grid.shape[1]
)

row_mask = jnp.arange(grid.shape[0])
col_mask = jnp.arange(grid.shape[1])

row_mask = (row_mask >= y_low) & (row_mask < y_high)
col_mask = (col_mask >= x_low) & (col_mask < x_high)

agent_mask = row_mask[:, None] & col_mask[None, :]

highlight_mask |= agent_mask

# Render the whole grid
img = self._render_grid(
grid,
highlight_mask=highlight_mask,
)
img = self._render_grid(grid, highlight_mask)
return img

@staticmethod
Expand Down Expand Up @@ -391,8 +396,8 @@ def _render_tile(

img = OvercookedV2Visualizer._render_cell(obj, img)

# if highlight:
# rendering.highlight_img(img)
img_highlight = rendering.highlight_img(img, highlight)
img = jax.lax.select(highlight, img_highlight, img)

# Downsample the image to perform supersampling/anti-aliasing
img = rendering.downsample(img, self.subdivs)
Expand All @@ -405,12 +410,9 @@ def _render_tile(
def _render_grid(
self,
grid,
highlight_mask=None,
highlight_mask,
):
if highlight_mask is None:
highlight_mask = jnp.zeros(shape=grid.shape[:2], dtype=bool)

img_grid = jax.vmap(jax.vmap(self._render_tile))(grid)
img_grid = jax.vmap(jax.vmap(self._render_tile))(grid, highlight_mask)

print("img_grid", img_grid.shape)

Expand Down

0 comments on commit 4a78659

Please sign in to comment.