Skip to content

Commit

Permalink
feat(overcooked): Auto cook pots by default
Browse files Browse the repository at this point in the history
  • Loading branch information
tobiges committed Jul 3, 2024
1 parent 02b1c97 commit 6887304
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 20 deletions.
14 changes: 13 additions & 1 deletion jaxmarl/environments/overcooked_v2/layouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,15 @@
WWWWWWWWW
"""

# Grid string for the layout registered under the key "more_fun_coordination",
# where it is passed to layout_grid_to_dict together with
# possible_recipes=[[0, 1, 1], [0, 2, 2]].
# The digits 0, 1 and 2 that appear in the grid coincide with the ingredient
# indices used by those recipes - presumably they mark ingredient piles;
# TODO confirm against the layout parser.
# NOTE(review): the rows as shown here have different widths, which suggests
# interior spacing may have been collapsed; verify column alignment.
more_fun_coordination = """
WWWWWWWWW
W X W
RA P A1
0 P 2
W B W
WWWWWWWWW
"""


@dataclass
class Layout:
Expand Down Expand Up @@ -264,6 +273,9 @@ def layout_grid_to_dict(grid, recipe=None, possible_recipes=None):
"two_rooms_simple": layout_grid_to_dict(two_rooms_simple),
"long_room": layout_grid_to_dict(long_room, recipe=[0, 0, 0]),
"fun_coordination": layout_grid_to_dict(
fun_coordination, possible_recipes=[[2, 2, 0], [3, 3, 1]]
fun_coordination, possible_recipes=[[0, 0, 2], [1, 1, 3]]
),
"more_fun_coordination": layout_grid_to_dict(
more_fun_coordination, possible_recipes=[[0, 1, 1], [0, 2, 2]]
),
}
45 changes: 26 additions & 19 deletions jaxmarl/environments/overcooked_v2/overcooked.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def __init__(
agent_view_size: Optional[int] = None,
random_reset: bool = False,
random_agent_positions: bool = False,
start_cooking_interaction: bool = False,
):
"""
Initializes the OvercookedV2 environment.
Expand All @@ -103,6 +104,7 @@ def __init__(
agent_view_size (Optional[int]): The number of blocks the agent can view in each direction, None for full grid.
random_reset (bool): Whether to reset the environment with random agent positions, inventories and pot states.
random_agent_positions (bool): Whether to randomize agent positions. Agents will not be moved outside of their room if they are placed in an enclosed space.
start_cooking_interaction (bool): If False, the pot starts cooking automatically once three ingredients are added; if True, the pot starts cooking only after an agent interacts with it.
"""

if isinstance(layout, str):
Expand Down Expand Up @@ -139,6 +141,10 @@ def __init__(
self.random_reset = random_reset
self.random_agent_positions = random_agent_positions

self.start_cooking_interaction = jnp.array(
start_cooking_interaction, dtype=jnp.bool
)

self.enclosed_spaces = compute_enclosed_spaces(
layout.static_objects == StaticObject.EMPTY,
)
Expand Down Expand Up @@ -217,11 +223,11 @@ def reset(
subkey, (), 0, len(self.possible_recipes)
)
fixed_recipe = self.possible_recipes[fixed_recipe_idx]
print("fixed_recipe: ", fixed_recipe)
# print("fixed_recipe: ", fixed_recipe)

recipe = DynamicObject.get_recipe_encoding(fixed_recipe)

print("Recipe: ", fixed_recipe)
# print("Recipe: ", fixed_recipe)

state = State(
agents=agents,
Expand All @@ -233,8 +239,6 @@ def reset(

key, key_randomize = jax.random.split(key)
if self.random_reset:
print("Random reset")

state = self._randomize_state(state, key_randomize)
elif self.random_agent_positions:
state = self._randomize_agent_positions(state, key_randomize)
Expand Down Expand Up @@ -271,7 +275,7 @@ def _select_agent_position(taken_mask, x):
_select_agent_position, taken_mask, (agents.pos, keys)
)

print("agent_positions: ", agent_positions)
# print("agent_positions: ", agent_positions)

return state.replace(agents=agents.replace(pos=agent_positions))

Expand Down Expand Up @@ -399,7 +403,7 @@ def _sample_counter_state(key):
key_grid = jax.random.split(subkey, (self.height, self.width))
new_grid = jax.vmap(jax.vmap(_sample_grid_states_wrapper))(grid, key_grid)

print("new_grid: ", new_grid)
# print("new_grid: ", new_grid)

return state.replace(
agents=agents.replace(inventory=agent_inventories),
Expand Down Expand Up @@ -651,7 +655,7 @@ def step_agents(
) -> Tuple[State, float]:
grid = state.grid

print("actions: ", actions)
# print("actions: ", actions)

# Move action:
# 1. move agent to new position (if possible on the grid)
Expand Down Expand Up @@ -734,7 +738,7 @@ def _interact_wrapper(carry, x):
def _interact(carry, agent):
grid, reward = carry

print("interact: ", agent.pos, agent.dir)
# print("interact: ", agent.pos, agent.dir)

new_grid, new_agent, interact_reward, shaped_reward = (
self.process_interact(
Expand Down Expand Up @@ -816,12 +820,12 @@ def process_interact(

inventory_is_empty = inventory == 0
inventory_is_ingredient = DynamicObject.is_ingredient(inventory)
print("inventory_is_ingredient: ", inventory_is_ingredient)
# print("inventory_is_ingredient: ", inventory_is_ingredient)
inventory_is_plate = inventory == DynamicObject.PLATE
inventory_is_dish = (inventory & DynamicObject.COOKED) != 0

merged_ingredients = interact_ingredients + inventory
print("merged_ingredients: ", merged_ingredients)
# print("merged_ingredients: ", merged_ingredients)

pot_is_cooking = object_is_pot * (interact_extra > 0)
pot_is_cooked = object_is_pot * (
Expand All @@ -843,12 +847,12 @@ def process_interact(
+ object_is_wall * ~object_has_no_ingredients * inventory_is_empty
)

print("successful_pickup: ", successful_pickup)
print("object_is_pile: ", object_is_pile)
print("inventory_is_empty: ", inventory_is_empty)
# print("successful_pickup: ", successful_pickup)
# print("object_is_pile: ", object_is_pile)
# print("inventory_is_empty: ", inventory_is_empty)

pot_full = DynamicObject.ingredient_count(interact_ingredients) == 3
print("pot_full: ", pot_full)
# print("pot_full: ", pot_full)

successful_pot_placement = pot_is_idle * inventory_is_ingredient * ~pot_full
ingredient_selector = inventory | (inventory << 1)
Expand Down Expand Up @@ -877,18 +881,20 @@ def process_interact(
object_is_plate_pile * DynamicObject.PLATE
+ object_is_ingredient_pile * StaticObject.get_ingredient(interact_item)
)
print("pile_ingredient: ", pile_ingredient)
# print("pile_ingredient: ", pile_ingredient)

new_ingredients = (
successful_drop * merged_ingredients + no_effect * interact_ingredients
)
pot_full_after_drop = DynamicObject.ingredient_count(new_ingredients) == 3

successful_pot_start_cooking = (
pot_is_idle * ~object_has_no_ingredients * inventory_is_empty
)
is_pot_start_cooking_useful = interact_ingredients == recipe
shaped_reward += (
successful_pot_start_cooking
self.start_cooking_interaction
* successful_pot_start_cooking
* is_pot_start_cooking_useful
# * jax.lax.select(
# is_pot_start_cooking_useful,
Expand All @@ -897,8 +903,9 @@ def process_interact(
# )
* SHAPED_REWARDS["POT_START_COOKING"]
)
auto_cook = pot_is_idle & pot_full_after_drop & ~self.start_cooking_interaction
new_extra = jax.lax.select(
successful_pot_start_cooking,
successful_pot_start_cooking | auto_cook,
POT_COOK_TIME,
interact_extra,
)
Expand All @@ -910,11 +917,11 @@ def process_interact(
successful_pickup * (pile_ingredient + merged_ingredients)
+ no_effect * inventory
)
print("new_inventory: ", new_inventory)
# print("new_inventory: ", new_inventory)
new_agent = agent.replace(inventory=new_inventory)

is_correct_recipe = inventory == plated_recipe
print("is_correct_recipe: ", is_correct_recipe)
# print("is_correct_recipe: ", is_correct_recipe)
reward = (
jnp.array(successful_delivery & is_correct_recipe, dtype=float)
* DELIVERY_REWARD
Expand Down

0 comments on commit 6887304

Please sign in to comment.