Skip to content

Commit

Permalink
feat(overcooked): Fix dynamic object viz
Browse files Browse the repository at this point in the history
  • Loading branch information
tobiges committed Jun 30, 2024
1 parent ebebadd commit 59e4a96
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 26 deletions.
14 changes: 7 additions & 7 deletions jaxmarl/environments/overcooked_v2/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,44 +76,44 @@ def get_ingredient_idx_list(obj):
@staticmethod
def get_ingredient_idx_list_jit(obj):

def loop_body(carry):
def _loop_body(carry):
obj, pos, idx, res = carry
count = obj & 0x3

cond = jnp.arange(MAX_INGREDIENTS)
cond = (cond < pos) & (res == -1) & (cond < pos + count)
cond = (cond >= pos) & (res == -1) & (cond < pos + count)

res = jnp.where(
cond,
res,
idx,
res,
)

return (obj >> 2, pos + count, idx + 1, res)

def loop_cond(carry):
def _loop_cond(carry):
obj, pos, _, _ = carry
return (obj > 0) & (pos < MAX_INGREDIENTS)

initial_res = jnp.full((MAX_INGREDIENTS,), -1, dtype=jnp.int32)
carry = (obj >> 2, 0, 0, initial_res)

val = jax.lax.while_loop(loop_cond, loop_body, carry)
val = jax.lax.while_loop(_loop_cond, _loop_body, carry)
return val[-1]

@staticmethod
def get_ingredient_idx(obj):

def _body_fun(val):
obj, idx, res = val
new_res = jax.lax.select(obj & 0x3, idx, res)
new_res = jax.lax.select(obj & 0x3 != 0, idx, res)
return (obj >> 2, idx + 1, new_res)

def _cond_fun(val):
obj, _, res = val
return (obj > 0) & (res == -1)

initial_val = (obj, 0, -1)
initial_val = (obj >> 2, 0, -1)
val = jax.lax.while_loop(_cond_fun, _body_fun, initial_val)
return val[-1]

Expand Down
52 changes: 33 additions & 19 deletions jaxmarl/viz/overcooked_v2_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,31 +137,45 @@ def _render_dynamic_item(
ingredient_fn=rendering.point_in_circle(0.5, 0.5, 0.15),
dish_positions=jnp.array([(0.5, 0.4), (0.4, 0.6), (0.6, 0.6)]),
):
img_plate = rendering.fill_coords(img, plate_fn, COLORS["white"])
img = jax.lax.select(ingredients & DynamicObject.PLATE, img_plate, img)
def _no_op(img, ingredients):
return img

# if DynamicObject.is_ingredient(ingredients):
idx = DynamicObject.get_ingredient_idx(ingredients)
img_ing = rendering.fill_coords(img, ingredient_fn, INGREDIENT_COLORS[idx])
img = jax.lax.select(DynamicObject.is_ingredient(ingredients), img_ing, img)
def _render_plate(img, ingredients):
return rendering.fill_coords(img, plate_fn, COLORS["white"])

# if ingredients & DynamicObject.COOKED:
def _render_cooked_ingredient(img, x):
idx, ingredient_idx = x
def _render_ingredient(img, ingredients):
idx = DynamicObject.get_ingredient_idx(ingredients)
return rendering.fill_coords(img, ingredient_fn, INGREDIENT_COLORS[idx])

color = INGREDIENT_COLORS[ingredient_idx]
pos = dish_positions[idx]
ingredient_fn = rendering.point_in_circle(pos[0], pos[1], 0.1)
img_ing = rendering.fill_coords(img, ingredient_fn, color)
def _render_dish(img, ingredients):
img = rendering.fill_coords(img, plate_fn, COLORS["white"])
ingredient_indices = DynamicObject.get_ingredient_idx_list_jit(ingredients)

img = jax.lax.select(ingredient_idx != -1, img_ing, img)
return img, None
for idx, ingredient_idx in enumerate(ingredient_indices):
color = INGREDIENT_COLORS[ingredient_idx]
pos = dish_positions[idx]
ingredient_fn = rendering.point_in_circle(pos[0], pos[1], 0.1)
img_ing = rendering.fill_coords(img, ingredient_fn, color)

img = jax.lax.select(ingredient_idx != -1, img_ing, img)

return img

ingredient_indices = DynamicObject.get_ingredient_idx_list_jit(ingredients)
img, _ = jax.lax.scan(
_render_cooked_ingredient,
branches = jnp.array(
[
ingredients == 0,
ingredients == DynamicObject.PLATE,
DynamicObject.is_ingredient(ingredients),
ingredients & DynamicObject.COOKED,
]
)
branch_idx = jnp.argmax(branches)

img = jax.lax.switch(
branch_idx,
[_no_op, _render_plate, _render_ingredient, _render_dish],
img,
(jnp.arange(len(ingredient_indices)), ingredient_indices),
ingredients,
)

return img
Expand Down

0 comments on commit 59e4a96

Please sign in to comment.