Skip to content

Commit

Permalink
feat(overcooked): Improve recipe selection
Browse files Browse the repository at this point in the history
  • Loading branch information
tobiges committed Jul 2, 2024
1 parent 2136f72 commit 0d4694b
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 20 deletions.
19 changes: 16 additions & 3 deletions jaxmarl/environments/overcooked_v2/layouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
from typing import List, Tuple, Optional
from dataclasses import dataclass
from .utils import get_possible_recipes
import itertools

cramped_room = """
WWPWW
Expand Down Expand Up @@ -100,7 +100,7 @@
fun_coordination = """
WWWWWWWWW
0 X 2
W P R
RA P AW
1 B 3
WWWWWWWWW
"""
Expand All @@ -126,13 +126,26 @@ class Layout:
# If possible_recipes is none, all possible recipes with the available ingredients will be considered
possible_recipes: Optional[List[List[int]]]

@staticmethod
def _get_possible_recipes(num_ingredients: int) -> List[List[int]]:
"""
Get all possible recipes given the number of ingredients.
"""
available_ingredients = list(range(num_ingredients)) * 3
raw_combinations = itertools.combinations(available_ingredients, 3)
unique_recipes = set(
tuple(sorted(combination)) for combination in raw_combinations
)

return [list(recipe) for recipe in unique_recipes]

def get_possible_recipes(self):
if self.recipe is not None:
possible_recipes = [self.recipe]
elif self.possible_recipes is not None:
possible_recipes = self.possible_recipes
else:
possible_recipes = get_possible_recipes(self.num_ingredients)
possible_recipes = self._get_possible_recipes(self.num_ingredients)
return possible_recipes


Expand Down
5 changes: 3 additions & 2 deletions jaxmarl/environments/overcooked_v2/overcooked.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from jaxmarl.environments.overcooked_v2.utils import (
compute_view_box,
tree_select,
get_possible_recipes,
compute_enclosed_spaces,
)

Expand Down Expand Up @@ -133,7 +132,9 @@ def __init__(

self.max_steps = max_steps

self.possible_recipes = layout.get_possible_recipes()
self.possible_recipes = jnp.array(
layout.get_possible_recipes(), dtype=jnp.int32
)

self.random_reset = random_reset
self.random_agent_positions = random_agent_positions
Expand Down
14 changes: 0 additions & 14 deletions jaxmarl/environments/overcooked_v2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import jax.numpy as jnp
from typing import List
import itertools
import chex
from collections import deque
from .common import Position, Direction


Expand All @@ -26,18 +24,6 @@ def compute_view_box(x, y, agent_view_size, height, width):
return x_low, x_high, y_low, y_high


def get_possible_recipes(num_ingredients: int) -> List[List[int]]:
"""
Get all possible recipes given the number of ingredients.
"""
available_ingredients = list(range(num_ingredients)) * 3
raw_combinations = itertools.combinations(available_ingredients, 3)
unique_recipes = set(tuple(sorted(combination)) for combination in raw_combinations)
possible_recipes = jnp.array(list(unique_recipes), dtype=jnp.int32)

return possible_recipes


def compute_enclosed_spaces(empty_mask: jnp.ndarray) -> jnp.ndarray:
"""
Compute the enclosed spaces in the environment.
Expand Down
1 change: 0 additions & 1 deletion jaxmarl/wrappers/baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,6 @@ def step(
batch_reward = self._batchify_floats(reward)
new_episode_return = state.episode_returns + self._batchify_floats(reward)
new_episode_length = state.episode_lengths + 1
new_won_episode = (batch_reward >= 1.0).astype(jnp.float32)

updated_recipe_returns = {
id: jax.lax.select(
Expand Down

0 comments on commit 0d4694b

Please sign in to comment.