diff --git a/jaxmarl/viz/grid_rendering.py b/jaxmarl/viz/grid_rendering.py index 061d42c54..8c7fc2bfb 100644 --- a/jaxmarl/viz/grid_rendering.py +++ b/jaxmarl/viz/grid_rendering.py @@ -12,8 +12,8 @@ def downsample(img, factor): img = img.reshape( [img.shape[0] // factor, factor, img.shape[1] // factor, factor, 3] ) - img = img.mean(axis=3) - img = img.mean(axis=1) + img = img.mean(axis=3, dtype=jnp.uint8) + img = img.mean(axis=1, dtype=jnp.uint8) return img diff --git a/jaxmarl/viz/overcooked_v2_visualizer.py b/jaxmarl/viz/overcooked_v2_visualizer.py index 193679710..10b2dec97 100644 --- a/jaxmarl/viz/overcooked_v2_visualizer.py +++ b/jaxmarl/viz/overcooked_v2_visualizer.py @@ -12,19 +12,19 @@ TILE_PIXELS = 32 COLORS = { - "red": jnp.array([255, 0, 0]), - "green": jnp.array([0, 255, 0]), - "blue": jnp.array([0, 0, 255]), - "purple": jnp.array([112, 39, 195]), - "yellow": jnp.array([255, 255, 0]), - "grey": jnp.array([100, 100, 100]), - "white": jnp.array([255, 255, 255]), - "black": jnp.array([25, 25, 25]), - "orange": jnp.array([230, 180, 0]), - "pink": jnp.array([255, 105, 180]), - "brown": jnp.array([139, 69, 19]), - "cyan": jnp.array([0, 255, 255]), - "light_blue": jnp.array([173, 216, 230]), + "red": jnp.array([255, 0, 0], dtype=jnp.uint8), + "green": jnp.array([0, 255, 0], dtype=jnp.uint8), + "blue": jnp.array([0, 0, 255], dtype=jnp.uint8), + "purple": jnp.array([112, 39, 195], dtype=jnp.uint8), + "yellow": jnp.array([255, 255, 0], dtype=jnp.uint8), + "grey": jnp.array([100, 100, 100], dtype=jnp.uint8), + "white": jnp.array([255, 255, 255], dtype=jnp.uint8), + "black": jnp.array([25, 25, 25], dtype=jnp.uint8), + "orange": jnp.array([230, 180, 0], dtype=jnp.uint8), + "pink": jnp.array([255, 105, 180], dtype=jnp.uint8), + "brown": jnp.array([139, 69, 19], dtype=jnp.uint8), + "cyan": jnp.array([0, 255, 255], dtype=jnp.uint8), + "light_blue": jnp.array([173, 216, 230], dtype=jnp.uint8), } INGREDIENT_COLORS = jnp.array( @@ -49,11 +49,12 @@ class OvercookedV2Visualizer: tile_cache = {} - def __init__(self, agent_view_size=None, tile_size=TILE_PIXELS): + def __init__(self, agent_view_size=None, tile_size=TILE_PIXELS, subdivs=3): self.window = None self.agent_view_size = agent_view_size self.tile_size = tile_size + self.subdivs = subdivs def _lazy_init_window(self): if self.window is None: @@ -83,6 +84,9 @@ def animate(self, state_seq, filename="animation.gif"): # frame_seq = [get_frame(state) for state in state_seq] frame_seq = jax.vmap(self._render_state)(state_seq) + # print("frame_seq", frame_seq) + print("frame_seq.shape", frame_seq.shape) + print("frame_seq.dtype", frame_seq.dtype) imageio.mimsave(filename, frame_seq, "GIF", duration=0.5) @@ -126,10 +130,9 @@ def _include_agents(grid, agent): # highlight_mask = highlight_mask.at[y_low:y_high, x_low:x_high].set(True) # Render the whole grid - img = OvercookedV2Visualizer._render_grid( + img = self._render_grid( grid, highlight_mask=highlight_mask, - tile_size=self.tile_size, ) return img @@ -181,7 +184,7 @@ def _render_wall(cell, img): img = rendering.fill_coords( img, rendering.point_in_rect(0, 1, 0, 1), COLORS["grey"] ) - img = OvercookedV2Visualizer._render_counter(cell[1], img) + img = OvercookedV2Visualizer._render_dynamic_item(cell[1], img) return img @@ -199,12 +202,19 @@ def _render_agent(cell, img): ) img = rendering.fill_coords(img, tri_fn, COLORS["red"]) - img = OvercookedV2Visualizer._render_inv(cell[1], img) + img = OvercookedV2Visualizer._render_dynamic_item( + cell[1], + img, + plate_fn=rendering.point_in_circle(0.75, 0.75, 0.2), + ingredient_fn=rendering.point_in_circle(0.75, 0.75, 0.15), + dish_positions=jnp.array([(0.65, 0.65), (0.85, 0.65), (0.75, 0.85)]), + ) return img def _render_agent_self(cell, img): - raise NotImplementedError() + # Note: This should not ever be called + return img def _render_goal(cell, img): img = rendering.fill_coords( @@ -226,7 +236,7 @@ def _render_recipe_indicator(cell, img): img = rendering.fill_coords( img, rendering.point_in_rect(0.1, 0.9, 0.1, 0.9), COLORS["brown"] ) - img = OvercookedV2Visualizer._render_counter(cell[1], img) + img = OvercookedV2Visualizer._render_dynamic_item(cell[1], img) return img @@ -284,6 +294,9 @@ def _render_ingredient_pile(cell, img): branch_idx = jnp.clip(static_object, 0, len(render_fns) - 1) + print("branch_idx", branch_idx) + print("render_fns", render_fns) + return jax.lax.switch( branch_idx, render_fns, @@ -291,42 +304,6 @@ def _render_ingredient_pile(cell, img): img, ) - @staticmethod - def _render_counter(ingredients, img): - plate_fn = rendering.point_in_circle(0.5, 0.5, 0.3) - img_plate = rendering.fill_coords(img, plate_fn, COLORS["white"]) - img = jax.lax.select(ingredients & DynamicObject.PLATE, img_plate, img) - - # if DynamicObject.is_ingredient(ingredients): - idx = DynamicObject.get_ingredient_idx(ingredients) - ingredient_fn = rendering.point_in_circle(0.5, 0.5, 0.15) - img_ing = rendering.fill_coords(img, ingredient_fn, INGREDIENT_COLORS[idx]) - img = jax.lax.select(DynamicObject.is_ingredient(ingredients), img_ing, img) - - # if ingredients & DynamicObject.COOKED: - def _render_cooked_ingredient(img, x): - idx, ingredient_idx = x - - positions = jnp.array([(0.5, 0.4), (0.4, 0.6), (0.6, 0.6)]) - - color = INGREDIENT_COLORS[ingredient_idx] - # pos = positions[jnp.minimum(idx, len(positions) - 1)] - pos = positions[idx] - ingredient_fn = rendering.point_in_circle(pos[0], pos[1], 0.1) - img_ing = rendering.fill_coords(img, ingredient_fn, color) - - img = jax.lax.select(ingredient_idx != -1, img_ing, img) - return img, None - - ingredient_indices = DynamicObject.get_ingredient_idx_list_jit(ingredients) - img, _ = jax.lax.scan( - _render_cooked_ingredient, - img, - (jnp.arange(len(ingredient_indices)), ingredient_indices), - ) - - return img - @staticmethod def _render_pot(cell, img): ingredients = cell[1] @@ -334,77 +311,57 @@ def _render_pot(cell, img): is_cooking = time_left > 0 is_cooked = (ingredients & DynamicObject.COOKED) != 0 - is_idle = not is_cooking and not is_cooked - ingredients = DynamicObject.get_ingredient_idx_list(ingredients) - - pot_fn = rendering.point_in_rect(0.1, 0.9, 0.33, 0.9) - lid_fn = rendering.point_in_rect(0.1, 0.9, 0.21, 0.25) - handle_fn = rendering.point_in_rect(0.4, 0.6, 0.16, 0.21) + is_idle = ~is_cooking & ~is_cooked + ingredients = DynamicObject.get_ingredient_idx_list_jit(ingredients) + has_ingredients = ingredients[0] != -1 rendering.fill_coords(img, rendering.point_in_rect(0, 1, 0, 1), COLORS["grey"]) - if len(ingredients) > 0: - ingredient_fns = [ - rendering.point_in_circle(*coord, 0.13) - for coord in [(0.23, 0.33), (0.77, 0.33), (0.50, 0.33)] - ] - [ - rendering.fill_coords( - img, ingredient_fns[i], INGREDIENT_COLORS[ingredient_idx] - ) - for i, ingredient_idx in enumerate(ingredients) - ] + # if len(ingredients) > 0: + ingredient_fns = [ + rendering.point_in_circle(*coord, 0.13) + for coord in [(0.23, 0.33), (0.77, 0.33), (0.50, 0.33)] + ] - if len(ingredients) > 0 and is_idle: - lid_fn = rendering.rotate_fn(lid_fn, cx=0.1, cy=0.25, theta=-0.1 * math.pi) - handle_fn = rendering.rotate_fn( - handle_fn, cx=0.1, cy=0.25, theta=-0.1 * math.pi + for i, ingredient_idx in enumerate(ingredients): + img_ing = rendering.fill_coords( + img, ingredient_fns[i], INGREDIENT_COLORS[ingredient_idx] ) + img = jax.lax.select(ingredient_idx != -1, img_ing, img) - # if is_cooked: - # TODO: maybe make it more obvious that the dish is cooked + pot_fn = rendering.point_in_rect(0.1, 0.9, 0.33, 0.9) + lid_fn = rendering.point_in_rect(0.1, 0.9, 0.21, 0.25) + handle_fn = rendering.point_in_rect(0.4, 0.6, 0.16, 0.21) - # Render the pot itself - pot_fns = [pot_fn, lid_fn, handle_fn] - [rendering.fill_coords(img, pot_fn, COLORS["black"]) for pot_fn in pot_fns] + lid_fn_open = rendering.rotate_fn(lid_fn, cx=0.1, cy=0.25, theta=-0.1 * math.pi) + handle_fn_open = rendering.rotate_fn( + handle_fn, cx=0.1, cy=0.25, theta=-0.1 * math.pi + ) + pot_open = is_idle & has_ingredients - # Render progress bar - if is_cooking: - progress_fn = rendering.point_in_rect( - 0.1, 0.9 - (0.9 - 0.1) / POT_COOK_TIME * time_left, 0.83, 0.88 - ) - rendering.fill_coords(img, progress_fn, COLORS["green"]) + img = rendering.fill_coords(img, pot_fn, COLORS["black"]) - @staticmethod - def _render_inv(ingredients, img): - # print("ingredients: ", ingredients) - # if DynamicObject.is_ingredient(ingredients): - idx = DynamicObject.get_ingredient_idx(ingredients) - ingredient_fn = rendering.point_in_circle(0.75, 0.75, 0.15) - img_ing = rendering.fill_coords(img, ingredient_fn, INGREDIENT_COLORS[idx]) - img = jax.lax.select(DynamicObject.is_ingredient(ingredients), img_ing, img) + img_closed = rendering.fill_coords(img, lid_fn, COLORS["black"]) + img_closed = rendering.fill_coords(img_closed, handle_fn, COLORS["black"]) - # if ingredients & DynamicObject.PLATE: - plate_fn = rendering.point_in_circle(0.75, 0.75, 0.2) - img_plate = rendering.fill_coords(img, plate_fn, COLORS["white"]) - img = jax.lax.select(ingredients & DynamicObject.PLATE, img_plate, img) + img_open = rendering.fill_coords(img, lid_fn_open, COLORS["black"]) + img_open = rendering.fill_coords(img_open, handle_fn_open, COLORS["black"]) - # if ingredients & DynamicObject.COOKED: - positions = [(0.65, 0.65), (0.85, 0.65), (0.75, 0.85)] - ingredient_indices = DynamicObject.get_ingredient_idx_list(ingredients) + img = jax.lax.select(pot_open, img_open, img_closed) - for idx, ingredient_idx in enumerate(ingredient_indices): - color = INGREDIENT_COLORS[ingredient_idx] - pos = positions[min(idx, len(positions) - 1)] - ingredient_fn = rendering.point_in_circle(pos[0], pos[1], 0.10) - rendering.fill_coords(img, ingredient_fn, color) + # Render progress bar + progress_fn = rendering.point_in_rect( + 0.1, 0.9 - (0.9 - 0.1) / POT_COOK_TIME * time_left, 0.83, 0.88 + ) + img_timer = rendering.fill_coords(img, progress_fn, COLORS["green"]) + img = jax.lax.select(is_cooking, img_timer, img) + + return img - @staticmethod def _render_tile( + self, obj, highlight=False, - tile_size=TILE_PIXELS, - subdivs=3, ): """ Render a tile and cache the result @@ -415,65 +372,84 @@ def _render_tile( # return OvercookedV2Visualizer.tile_cache[key] img = jnp.zeros( - shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=jnp.uint8 + shape=(self.tile_size * self.subdivs, self.tile_size * self.subdivs, 3), + dtype=jnp.uint8, ) # Draw the grid lines (top and left edges) rendering.fill_coords( - img, rendering.point_in_rect(0, 0.031, 0, 1), jnp.array([100, 100, 100]) + img, rendering.point_in_rect(0, 0.031, 0, 1), COLORS["grey"] ) rendering.fill_coords( - img, rendering.point_in_rect(0, 1, 0, 0.031), jnp.array([100, 100, 100]) + img, rendering.point_in_rect(0, 1, 0, 0.031), COLORS["grey"] ) - OvercookedV2Visualizer._render_cell(obj, img) + img = OvercookedV2Visualizer._render_cell(obj, img) + + # if highlight: + # rendering.highlight_img(img) - if highlight: - rendering.highlight_img(img) + print("img", img.shape) + print("img.dtype", img.dtype) # Downsample the image to perform supersampling/anti-aliasing - img = rendering.downsample(img, subdivs) + img = rendering.downsample(img, self.subdivs) + + print("img", img.shape) + print("img.dtype", img.dtype) # Cache the rendered tile # OvercookedV2Visualizer.tile_cache[key] = img return img - @staticmethod def _render_grid( + self, grid, highlight_mask=None, - tile_size=TILE_PIXELS, ): if highlight_mask is None: highlight_mask = jnp.zeros(shape=grid.shape[:2], dtype=bool) - # Compute the total grid size in pixels - width_px = grid.shape[1] * tile_size - height_px = grid.shape[0] * tile_size - - img = jnp.zeros(shape=(height_px, width_px, 3), dtype=jnp.uint8) - - def _set_tile(x, y, tile_img): - ymin = y * tile_size - ymax = (y + 1) * tile_size - xmin = x * tile_size - xmax = (x + 1) * tile_size - img[ymin:ymax, xmin:xmax, :] = tile_img - - # Render the grid - for y in range(grid.shape[0]): - for x in range(grid.shape[1]): - cell = grid[y, x] - tile_img = OvercookedV2Visualizer._render_tile( - cell, - highlight=highlight_mask[y, x], - tile_size=tile_size, - ) + # # Compute the total grid size in pixels + # width_px = grid.shape[1] * tile_size + # height_px = grid.shape[0] * tile_size + + # img = jnp.zeros(shape=(height_px, width_px, 3), dtype=jnp.uint8) + + # def _set_tile(x, y, tile_img): + # ymin = y * tile_size + # ymax = (y + 1) * tile_size + # xmin = x * tile_size + # xmax = (x + 1) * tile_size + # img[ymin:ymax, xmin:xmax, :] = tile_img + + # # Render the grid + # for y in range(grid.shape[0]): + # for x in range(grid.shape[1]): + # cell = grid[y, x] + # tile_img = OvercookedV2Visualizer._render_tile( + # cell, + # highlight=highlight_mask[y, x], + # tile_size=tile_size, + # ) - _set_tile(x, y, tile_img) + # _set_tile(x, y, tile_img) - return img + img_grid = jax.vmap(jax.vmap(self._render_tile))(grid) + + print("img_grid", img_grid.shape) + + # img = img_grid + + grid_rows, grid_cols, tile_height, tile_width, channels = img_grid.shape + + # Reshape and transpose to merge the grid + big_image = img_grid.transpose(0, 2, 1, 3, 4).reshape( + grid_rows * tile_height, grid_cols * tile_width, channels + ) + + return big_image def close(self): self.window.close()