diff --git a/jaxmarl/viz/overcooked_v2_visualizer.py b/jaxmarl/viz/overcooked_v2_visualizer.py index 10b2dec9..2031e068 100644 --- a/jaxmarl/viz/overcooked_v2_visualizer.py +++ b/jaxmarl/viz/overcooked_v2_visualizer.py @@ -1,7 +1,7 @@ import math from jaxmarl.environments.overcooked_v2.utils import compute_view_box from jaxmarl.viz.window import Window -import jaxmarl.viz.grid_rendering as rendering +import jaxmarl.viz.grid_rendering_v2 as rendering import jax import jax.numpy as jnp from jaxmarl.environments.overcooked_v2.common import StaticObject, DynamicObject @@ -49,10 +49,9 @@ class OvercookedV2Visualizer: tile_cache = {} - def __init__(self, agent_view_size=None, tile_size=TILE_PIXELS, subdivs=3): + def __init__(self, tile_size=TILE_PIXELS, subdivs=3): self.window = None - self.agent_view_size = agent_view_size self.tile_size = tile_size self.subdivs = subdivs @@ -64,34 +63,28 @@ def show(self, block=False): self._lazy_init_window() self.window.show(block=block) - def render(self, state): + def render(self, state, agent_view_size=None): """Method for rendering the state in a window. Esp. useful for interactive mode.""" self._lazy_init_window() - img = self._render_state(state) + img = self._render_state(state, agent_view_size) self.window.show_img(img) - def animate(self, state_seq, filename="animation.gif"): + def animate(self, state_seq, filename="animation.gif", agent_view_size=None): """Animate a gif give a state sequence and save if to file.""" - # def get_frame(state): - # frame = OvercookedV2Visualizer._render_state( - # state, agent_view_size=agent_view_size - # ) - # return frame - - # frame_seq = [get_frame(state) for state in state_seq] - - frame_seq = jax.vmap(self._render_state)(state_seq) + frame_seq = jax.vmap(self._render_state, in_axes=(0, None))( + state_seq, agent_view_size + ) # print("frame_seq", frame_seq) print("frame_seq.shape", frame_seq.shape) print("frame_seq.dtype", frame_seq.dtype) imageio.mimsave(filename, frame_seq, "GIF", duration=0.5) - @partial(jax.jit, static_argnums=(0,)) - def _render_state(self, state): + @partial(jax.jit, static_argnums=(0, 2)) + def _render_state(self, state, agent_view_size=None): """ Render the state """ @@ -287,16 +280,13 @@ def _render_ingredient_pile(cell, img): StaticObject.PLATE_PILE: _render_plate_pile, } - render_fns = [_render_empty] * (max(render_fns_dict.keys()) + 1) + render_fns = [_render_empty] * (max(render_fns_dict.keys()) + 2) for key, value in render_fns_dict.items(): render_fns[key] = value render_fns[-1] = _render_ingredient_pile branch_idx = jnp.clip(static_object, 0, len(render_fns) - 1) - print("branch_idx", branch_idx) - print("render_fns", render_fns) - return jax.lax.switch( branch_idx, render_fns, @@ -315,9 +305,10 @@ def _render_pot(cell, img): ingredients = DynamicObject.get_ingredient_idx_list_jit(ingredients) has_ingredients = ingredients[0] != -1 - rendering.fill_coords(img, rendering.point_in_rect(0, 1, 0, 1), COLORS["grey"]) + img = rendering.fill_coords( + img, rendering.point_in_rect(0, 1, 0, 1), COLORS["grey"] + ) - # if len(ingredients) > 0: ingredient_fns = [ rendering.point_in_circle(*coord, 0.13) for coord in [(0.23, 0.33), (0.77, 0.33), (0.50, 0.33)] @@ -377,10 +368,10 @@ def _render_tile( ) # Draw the grid lines (top and left edges) - rendering.fill_coords( + img = rendering.fill_coords( img, rendering.point_in_rect(0, 0.031, 0, 1), COLORS["grey"] ) - rendering.fill_coords( + img = rendering.fill_coords( img, rendering.point_in_rect(0, 1, 0, 0.031), COLORS["grey"] ) @@ -389,15 +380,9 @@ def _render_tile( # if highlight: # rendering.highlight_img(img) - print("img", img.shape) - print("img.dtype", img.dtype) - # Downsample the image to perform supersampling/anti-aliasing img = rendering.downsample(img, self.subdivs) - print("img", img.shape) - print("img.dtype", img.dtype) - # Cache the rendered tile # OvercookedV2Visualizer.tile_cache[key] = img @@ -411,44 +396,18 @@ def _render_grid( if highlight_mask is None: highlight_mask = jnp.zeros(shape=grid.shape[:2], dtype=bool) - # # Compute the total grid size in pixels - # width_px = grid.shape[1] * tile_size - # height_px = grid.shape[0] * tile_size - - # img = jnp.zeros(shape=(height_px, width_px, 3), dtype=jnp.uint8) - - # def _set_tile(x, y, tile_img): - # ymin = y * tile_size - # ymax = (y + 1) * tile_size - # xmin = x * tile_size - # xmax = (x + 1) * tile_size - # img[ymin:ymax, xmin:xmax, :] = tile_img - - # # Render the grid - # for y in range(grid.shape[0]): - # for x in range(grid.shape[1]): - # cell = grid[y, x] - # tile_img = OvercookedV2Visualizer._render_tile( - # cell, - # highlight=highlight_mask[y, x], - # tile_size=tile_size, - # ) - - # _set_tile(x, y, tile_img) - img_grid = jax.vmap(jax.vmap(self._render_tile))(grid) print("img_grid", img_grid.shape) - # img = img_grid - grid_rows, grid_cols, tile_height, tile_width, channels = img_grid.shape - # Reshape and transpose to merge the grid big_image = img_grid.transpose(0, 2, 1, 3, 4).reshape( grid_rows * tile_height, grid_cols * tile_width, channels ) + print("big_image", big_image.shape) + return big_image def close(self): diff --git a/jaxmarl/viz/test.py b/jaxmarl/viz/test.py new file mode 100644 index 00000000..81fcf122 --- /dev/null +++ b/jaxmarl/viz/test.py @@ -0,0 +1,113 @@ +""" +Short introduction to running the Overcooked environment and visualising it using random actions. +""" + +import jax +from jaxmarl import make +from jaxmarl.viz.overcooked_v2_visualizer import OvercookedV2Visualizer +import time + +# Parameters + random keys +max_steps = 2 +key = jax.random.PRNGKey(0) + +# Get one of the classic layouts (cramped_room, asymm_advantages, coord_ring, forced_coord, counter_circuit) +layout = "cramped_room" + +# Or make your own! +# custom_layout_grid = """ +# WWOWW +# WA W +# B P X +# W AW +# WWOWW +# """ +# layout = layout_grid_to_dict(custom_layout_grid) + +# Instantiate environment +env = make("overcooked_v2", layout=layout, max_steps=max_steps) + + +def part_1(key): + key, key_r, key_a = jax.random.split(key, 3) + + obs, state = env.reset(key_r) + print("list of agents in environment", env.agents) + + # Sample random actions + key_a = jax.random.split(key_a, env.num_agents) + actions = { + agent: env.action_space(agent).sample(key_a[i]) + for i, agent in enumerate(env.agents) + } + print("example action dict", actions) + + # state_seq = [] + # for _ in range(max_steps): + # state_seq.append(state) + # # Iterate random keys and sample actions + # key, key_s, key_a = jax.random.split(key, 3) + # key_a = jax.random.split(key_a, env.num_agents) + + # actions = { + # agent: env.action_space(agent).sample(key_a[i]) + # for i, agent in enumerate(env.agents) + # } + + # # Step environment + # obs, state, rewards, dones, infos = env.step(key_s, state, actions) + + def _step(state, key): + key_action, key_step = jax.random.split(key) + actions = { + agent: env.action_space(agent).sample(key_action) + for i, agent in enumerate(env.agents) + } + + # Step environment + obs, state, rewards, dones, infos = env.step(key_step, state, actions) + + return state, state + + keys = jax.random.split(key, max_steps) + _, state_seq = jax.lax.scan(_step, state, keys) + + return state_seq + + +def part_2(state_seq): + # Visualize + viz = OvercookedV2Visualizer() + + # print(state_seq) + + # Or save an animation + viz.animate(state_seq, filename="animation.gif", agent_view_size=1) + + +# viz = OvercookedV2Visualizer() + +# # # Render to screen +# # for s in state_seq: +# # viz.render(env.agent_view_size, s, highlight=False) +# # time.sleep(0.25) + +# # # Or save an animation +# viz.animate(state_seq, agent_view_size=5, filename="animation.gif") + + +if __name__ == "__main__": + + start_time = time.time() # Renamed variable to avoid conflict + + with jax.disable_jit(True): + state_seq = part_1(key) + print("done part 1") + + part_2(state_seq) + + print("done") + + print( + "time taken", time.time() - start_time + ) # Updated to use the new variable name