Skip to content

Commit

Permalink
feat(overcooked): loads of viz upgrades
Browse files Browse the repository at this point in the history
  • Loading branch information
tobiges committed Jun 30, 2024
1 parent 7e97093 commit ebebadd
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 59 deletions.
77 changes: 18 additions & 59 deletions jaxmarl/viz/overcooked_v2_visualizer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import math
from jaxmarl.environments.overcooked_v2.utils import compute_view_box
from jaxmarl.viz.window import Window
import jaxmarl.viz.grid_rendering as rendering
import jaxmarl.viz.grid_rendering_v2 as rendering
import jax
import jax.numpy as jnp
from jaxmarl.environments.overcooked_v2.common import StaticObject, DynamicObject
Expand Down Expand Up @@ -49,10 +49,9 @@ class OvercookedV2Visualizer:

tile_cache = {}

def __init__(self, agent_view_size=None, tile_size=TILE_PIXELS, subdivs=3):
def __init__(self, tile_size=TILE_PIXELS, subdivs=3):
self.window = None

self.agent_view_size = agent_view_size
self.tile_size = tile_size
self.subdivs = subdivs

Expand All @@ -64,34 +63,28 @@ def show(self, block=False):
self._lazy_init_window()
self.window.show(block=block)

def render(self, state):
def render(self, state, agent_view_size=None):
"""Method for rendering the state in a window. Esp. useful for interactive mode."""
self._lazy_init_window()

img = self._render_state(state)
img = self._render_state(state, agent_view_size)

self.window.show_img(img)

def animate(self, state_seq, filename="animation.gif"):
def animate(self, state_seq, filename="animation.gif", agent_view_size=None):
"""Animate a gif give a state sequence and save if to file."""

# def get_frame(state):
# frame = OvercookedV2Visualizer._render_state(
# state, agent_view_size=agent_view_size
# )
# return frame

# frame_seq = [get_frame(state) for state in state_seq]

frame_seq = jax.vmap(self._render_state)(state_seq)
frame_seq = jax.vmap(self._render_state, in_axes=(0, None))(
state_seq, agent_view_size
)
# print("frame_seq", frame_seq)
print("frame_seq.shape", frame_seq.shape)
print("frame_seq.dtype", frame_seq.dtype)

imageio.mimsave(filename, frame_seq, "GIF", duration=0.5)

@partial(jax.jit, static_argnums=(0,))
def _render_state(self, state):
@partial(jax.jit, static_argnums=(0, 2))
def _render_state(self, state, agent_view_size=None):
"""
Render the state
"""
Expand Down Expand Up @@ -287,16 +280,13 @@ def _render_ingredient_pile(cell, img):
StaticObject.PLATE_PILE: _render_plate_pile,
}

render_fns = [_render_empty] * (max(render_fns_dict.keys()) + 1)
render_fns = [_render_empty] * (max(render_fns_dict.keys()) + 2)
for key, value in render_fns_dict.items():
render_fns[key] = value
render_fns[-1] = _render_ingredient_pile

branch_idx = jnp.clip(static_object, 0, len(render_fns) - 1)

print("branch_idx", branch_idx)
print("render_fns", render_fns)

return jax.lax.switch(
branch_idx,
render_fns,
Expand All @@ -315,9 +305,10 @@ def _render_pot(cell, img):
ingredients = DynamicObject.get_ingredient_idx_list_jit(ingredients)
has_ingredients = ingredients[0] != -1

rendering.fill_coords(img, rendering.point_in_rect(0, 1, 0, 1), COLORS["grey"])
img = rendering.fill_coords(
img, rendering.point_in_rect(0, 1, 0, 1), COLORS["grey"]
)

# if len(ingredients) > 0:
ingredient_fns = [
rendering.point_in_circle(*coord, 0.13)
for coord in [(0.23, 0.33), (0.77, 0.33), (0.50, 0.33)]
Expand Down Expand Up @@ -377,10 +368,10 @@ def _render_tile(
)

# Draw the grid lines (top and left edges)
rendering.fill_coords(
img = rendering.fill_coords(
img, rendering.point_in_rect(0, 0.031, 0, 1), COLORS["grey"]
)
rendering.fill_coords(
img = rendering.fill_coords(
img, rendering.point_in_rect(0, 1, 0, 0.031), COLORS["grey"]
)

Expand All @@ -389,15 +380,9 @@ def _render_tile(
# if highlight:
# rendering.highlight_img(img)

print("img", img.shape)
print("img.dtype", img.dtype)

# Downsample the image to perform supersampling/anti-aliasing
img = rendering.downsample(img, self.subdivs)

print("img", img.shape)
print("img.dtype", img.dtype)

# Cache the rendered tile
# OvercookedV2Visualizer.tile_cache[key] = img

Expand All @@ -411,44 +396,18 @@ def _render_grid(
if highlight_mask is None:
highlight_mask = jnp.zeros(shape=grid.shape[:2], dtype=bool)

# # Compute the total grid size in pixels
# width_px = grid.shape[1] * tile_size
# height_px = grid.shape[0] * tile_size

# img = jnp.zeros(shape=(height_px, width_px, 3), dtype=jnp.uint8)

# def _set_tile(x, y, tile_img):
# ymin = y * tile_size
# ymax = (y + 1) * tile_size
# xmin = x * tile_size
# xmax = (x + 1) * tile_size
# img[ymin:ymax, xmin:xmax, :] = tile_img

# # Render the grid
# for y in range(grid.shape[0]):
# for x in range(grid.shape[1]):
# cell = grid[y, x]
# tile_img = OvercookedV2Visualizer._render_tile(
# cell,
# highlight=highlight_mask[y, x],
# tile_size=tile_size,
# )

# _set_tile(x, y, tile_img)

img_grid = jax.vmap(jax.vmap(self._render_tile))(grid)

print("img_grid", img_grid.shape)

# img = img_grid

grid_rows, grid_cols, tile_height, tile_width, channels = img_grid.shape

# Reshape and transpose to merge the grid
big_image = img_grid.transpose(0, 2, 1, 3, 4).reshape(
grid_rows * tile_height, grid_cols * tile_width, channels
)

print("big_image", big_image.shape)

return big_image

def close(self):
Expand Down
113 changes: 113 additions & 0 deletions jaxmarl/viz/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
"""
Short introduction to running the Overcooked environment and visualising it using random actions.
"""

import jax
from jaxmarl import make
from jaxmarl.viz.overcooked_v2_visualizer import OvercookedV2Visualizer
import time

# Parameters + random keys
max_steps = 2
key = jax.random.PRNGKey(0)

# Get one of the classic layouts (cramped_room, asymm_advantages, coord_ring, forced_coord, counter_circuit)
layout = "cramped_room"

# Or make your own!
# custom_layout_grid = """
# WWOWW
# WA W
# B P X
# W AW
# WWOWW
# """
# layout = layout_grid_to_dict(custom_layout_grid)

# Instantiate environment
env = make("overcooked_v2", layout=layout, max_steps=max_steps)


def part_1(key):
key, key_r, key_a = jax.random.split(key, 3)

obs, state = env.reset(key_r)
print("list of agents in environment", env.agents)

# Sample random actions
key_a = jax.random.split(key_a, env.num_agents)
actions = {
agent: env.action_space(agent).sample(key_a[i])
for i, agent in enumerate(env.agents)
}
print("example action dict", actions)

# state_seq = []
# for _ in range(max_steps):
# state_seq.append(state)
# # Iterate random keys and sample actions
# key, key_s, key_a = jax.random.split(key, 3)
# key_a = jax.random.split(key_a, env.num_agents)

# actions = {
# agent: env.action_space(agent).sample(key_a[i])
# for i, agent in enumerate(env.agents)
# }

# # Step environment
# obs, state, rewards, dones, infos = env.step(key_s, state, actions)

def _step(state, key):
key_action, key_step = jax.random.split(key)
actions = {
agent: env.action_space(agent).sample(key_action)
for i, agent in enumerate(env.agents)
}

# Step environment
obs, state, rewards, dones, infos = env.step(key_step, state, actions)

return state, state

keys = jax.random.split(key, max_steps)
_, state_seq = jax.lax.scan(_step, state, keys)

return state_seq


def part_2(state_seq):
# Visualize
viz = OvercookedV2Visualizer()

# print(state_seq)

# Or save an animation
viz.animate(state_seq, filename="animation.gif", agent_view_size=1)


# viz = OvercookedV2Visualizer()

# # # Render to screen
# # for s in state_seq:
# # viz.render(env.agent_view_size, s, highlight=False)
# # time.sleep(0.25)

# # # Or save an animation
# viz.animate(state_seq, agent_view_size=5, filename="animation.gif")


if __name__ == "__main__":

start_time = time.time() # Renamed variable to avoid conflict

with jax.disable_jit(True):
state_seq = part_1(key)
print("done part 1")

part_2(state_seq)

print("done")

print(
"time taken", time.time() - start_time
) # Updated to use the new variable name

0 comments on commit ebebadd

Please sign in to comment.