feat(overcooked): Add custom overcooked logger
tobiges committed Jun 28, 2024
1 parent 4cca5f7 commit 61cef11
Showing 1 changed file with 173 additions and 38 deletions.
211 changes: 173 additions & 38 deletions jaxmarl/wrappers/baselines.py
@@ -1,4 +1,5 @@
""" Wrappers for use with jaxmarl baselines. """

import os
import jax
import jax.numpy as jnp
@@ -10,17 +11,20 @@
# from gymnax.environments import environment, spaces
from gymnax.environments.spaces import Box as BoxGymnax, Discrete as DiscreteGymnax
from typing import Dict, Optional, List, Tuple, Union
from jaxmarl.environments.overcooked_v2.common import DynamicObject
from jaxmarl.environments.spaces import Box, Discrete, MultiDiscrete
from jaxmarl.environments.multi_agent_env import MultiAgentEnv, State

from safetensors.flax import save_file, load_file
from flax.traverse_util import flatten_dict, unflatten_dict


def save_params(params: Dict, filename: Union[str, os.PathLike]) -> None:
flattened_dict = flatten_dict(params, sep=',')
flattened_dict = flatten_dict(params, sep=",")
save_file(flattened_dict, filename)

def load_params(filename:Union[str, os.PathLike]) -> Dict:

def load_params(filename: Union[str, os.PathLike]) -> Dict:
flattened_dict = load_file(filename)
return unflatten_dict(flattened_dict, sep=",")

@@ -101,10 +105,97 @@ def step(
info["returned_episode"] = jnp.full((self._env.num_agents,), ep_done)
return obs, state, reward, done, info


@struct.dataclass
class OvercookedV2LogEnvState:
env_state: State
episode_returns: float
episode_lengths: int
returned_episode_returns: float
returned_episode_lengths: int
returned_episode_recipe_returns: Dict[str, float]


class OvercookedV2LogWrapper(JaxMARLWrapper):
def __init__(self, env: MultiAgentEnv, replace_info: bool = False):
super().__init__(env)
self.replace_info = replace_info

self.recipe_dict = {
f"{recipe[0]}_{recipe[1]}_{recipe[2]}": DynamicObject.get_recipe_encoding(
recipe
)
for recipe in env.possible_recipes
}

@partial(jax.jit, static_argnums=(0,))
def reset(self, key: chex.PRNGKey) -> Tuple[chex.Array, State]:
obs, env_state = self._env.reset(key)

recipe_returns = {
r: jnp.zeros((self._env.num_agents,)) for r in self.recipe_dict
}

state = OvercookedV2LogEnvState(
env_state,
jnp.zeros((self._env.num_agents,)),
jnp.zeros((self._env.num_agents,)),
jnp.zeros((self._env.num_agents,)),
jnp.zeros((self._env.num_agents,)),
recipe_returns,
)
return obs, state

@partial(jax.jit, static_argnums=(0,))
def step(
self,
key: chex.PRNGKey,
state: OvercookedV2LogEnvState,
action: Union[int, float],
) -> Tuple[chex.Array, LogEnvState, float, bool, dict]:
obs, env_state, reward, done, info = self._env.step(
key, state.env_state, action
)
ep_done = done["__all__"]
batch_reward = self._batchify_floats(reward)
new_episode_return = state.episode_returns + self._batchify_floats(reward)
new_episode_length = state.episode_lengths + 1
new_won_episode = (batch_reward >= 1.0).astype(jnp.float32)

updated_recipe_returns = {
id: jax.lax.select(
(state.env_state.recipe == self.recipe_dict[id]) & ep_done,
new_episode_return,
old_episode_return,
)
for id, old_episode_return in state.returned_episode_recipe_returns.items()
}

state = OvercookedV2LogEnvState(
env_state=env_state,
episode_returns=new_episode_return * (1 - ep_done),
episode_lengths=new_episode_length * (1 - ep_done),
returned_episode_returns=jax.lax.select(
ep_done, new_episode_return, state.returned_episode_returns
),
returned_episode_lengths=jax.lax.select(
ep_done, new_episode_length, state.returned_episode_lengths
),
returned_episode_recipe_returns=updated_recipe_returns,
)
if self.replace_info:
info = {}
info["returned_episode_returns"] = state.returned_episode_returns
info["returned_episode_lengths"] = state.returned_episode_lengths
info["returned_episode"] = jnp.full((self._env.num_agents,), ep_done)
info["returned_episode_recipe_returns"] = state.returned_episode_recipe_returns
return obs, state, reward, done, info


class MPELogWrapper(LogWrapper):
""" Times reward signal by number of agents within the environment,
to match the on-policy codebase. """
"""Times reward signal by number of agents within the environment,
to match the on-policy codebase."""

@partial(jax.jit, static_argnums=(0,))
def step(
self,
@@ -115,7 +206,9 @@ def step(
obs, env_state, reward, done, info = self._env.step(
key, state.env_state, action
)
rewardlog = jax.tree_map(lambda x: x*self._env.num_agents, reward) # As per on-policy codebase
rewardlog = jax.tree_map(
lambda x: x * self._env.num_agents, reward
) # As per on-policy codebase
ep_done = done["__all__"]
new_episode_return = state.episode_returns + self._batchify_floats(rewardlog)
new_episode_length = state.episode_lengths + 1
@@ -135,6 +228,7 @@ def step(
info["returned_episode"] = jnp.full((self._env.num_agents,), ep_done)
return obs, state, reward, done, info


@struct.dataclass
class SMAXLogEnvState:
env_state: State
@@ -209,7 +303,9 @@ def get_space_dim(space):
return np.prod(space.shape)
else:
print(space)
raise NotImplementedError('Current wrapper works only with Discrete/MultiDiscrete/Box action and obs spaces')
raise NotImplementedError(
"Current wrapper works only with Discrete/MultiDiscrete/Box action and obs spaces"
)


class CTRolloutManager(JaxMARLWrapper):
@@ -225,46 +321,69 @@ class CTRolloutManager(JaxMARLWrapper):
- global_reward is the sum of all agents' rewards.
"""

def __init__(self, env: MultiAgentEnv, batch_size:int, training_agents:List=None, preprocess_obs:bool=True):

def __init__(
self,
env: MultiAgentEnv,
batch_size: int,
training_agents: List = None,
preprocess_obs: bool = True,
):

super().__init__(env)

self.batch_size = batch_size

# the agents to train could differ from the total trainable agents in the env (f.i. if using pretrained agents)
# it's important to know it in order to compute properly the default global rewards and state
self.training_agents = self.agents if training_agents is None else training_agents
self.preprocess_obs = preprocess_obs
self.training_agents = (
self.agents if training_agents is None else training_agents
)
self.preprocess_obs = preprocess_obs

# TOREMOVE: this is because overcooked doesn't follow other envs conventions
if len(env.observation_spaces) == 0:
self.observation_spaces = {agent:self.observation_space() for agent in self.agents}
self.observation_spaces = {
agent: self.observation_space() for agent in self.agents
}
if len(env.action_spaces) == 0:
self.action_spaces = {agent:env.action_space() for agent in self.agents}
self.action_spaces = {agent: env.action_space() for agent in self.agents}

# batched action sampling
self.batch_samplers = {agent: jax.jit(jax.vmap(self.action_space(agent).sample, in_axes=0)) for agent in self.agents}
self.batch_samplers = {
agent: jax.jit(jax.vmap(self.action_space(agent).sample, in_axes=0))
for agent in self.agents
}

# assumes the observations are flattened vectors
self.max_obs_length = max(list(map(lambda x: get_space_dim(x), self.observation_spaces.values())))
self.max_action_space = max(list(map(lambda x: get_space_dim(x), self.action_spaces.values())))
self.max_obs_length = max(
list(map(lambda x: get_space_dim(x), self.observation_spaces.values()))
)
self.max_action_space = max(
list(map(lambda x: get_space_dim(x), self.action_spaces.values()))
)
self.obs_size = self.max_obs_length + len(self.agents)

# agents ids
self.agents_one_hot = {a:oh for a, oh in zip(self.agents, jnp.eye(len(self.agents)))}
self.agents_one_hot = {
a: oh for a, oh in zip(self.agents, jnp.eye(len(self.agents)))
}
# valid actions
self.valid_actions = {a:jnp.arange(u.n) for a, u in self.action_spaces.items()}
self.valid_actions_oh ={a:jnp.concatenate((jnp.ones(u.n), jnp.zeros(self.max_action_space - u.n))) for a, u in self.action_spaces.items()}
self.valid_actions = {a: jnp.arange(u.n) for a, u in self.action_spaces.items()}
self.valid_actions_oh = {
a: jnp.concatenate((jnp.ones(u.n), jnp.zeros(self.max_action_space - u.n)))
for a, u in self.action_spaces.items()
}

# custom global state and rewards for specific envs
if 'smax' in env.name.lower():
self.global_state = lambda obs, state: obs['world_state']
self.global_reward = lambda rewards: rewards[self.training_agents[0]]*10
elif 'overcooked' in env.name.lower():
self.global_state = lambda obs, state: jnp.concatenate([obs[agent].ravel() for agent in self.agents], axis=-1)
if "smax" in env.name.lower():
self.global_state = lambda obs, state: obs["world_state"]
self.global_reward = lambda rewards: rewards[self.training_agents[0]] * 10
elif "overcooked" in env.name.lower():
self.global_state = lambda obs, state: jnp.concatenate(
[obs[agent].ravel() for agent in self.agents], axis=-1
)
self.global_reward = lambda rewards: rewards[self.training_agents[0]]


@partial(jax.jit, static_argnums=0)
def batch_reset(self, key):
keys = jax.random.split(key, self.batch_size)
@@ -279,20 +398,32 @@ def batch_step(self, key, states, actions):
def wrapped_reset(self, key):
obs_, state = self._env.reset(key)
if self.preprocess_obs:
obs = jax.tree_util.tree_map(self._preprocess_obs, {agent:obs_[agent] for agent in self.agents}, self.agents_one_hot)
obs = jax.tree_util.tree_map(
self._preprocess_obs,
{agent: obs_[agent] for agent in self.agents},
self.agents_one_hot,
)
else:
obs = obs_
obs["__all__"] = self.global_state(obs_, state)
return obs, state

@partial(jax.jit, static_argnums=0)
def wrapped_step(self, key, state, actions):
if 'hanabi' in self._env.name.lower():
actions = jax.tree_util.tree_map(lambda x:jnp.expand_dims(x, 0), actions)
if "hanabi" in self._env.name.lower():
actions = jax.tree_util.tree_map(lambda x: jnp.expand_dims(x, 0), actions)
obs_, state, reward, done, infos = self._env.step(key, state, actions)
if self.preprocess_obs:
obs = jax.tree_util.tree_map(self._preprocess_obs, {agent:obs_[agent] for agent in self.agents}, self.agents_one_hot)
obs = jax.tree_util.tree_map(lambda d, o: jnp.where(d, 0., o), {agent:done[agent] for agent in self.agents}, obs) # ensure that the obs are 0s for done agents
obs = jax.tree_util.tree_map(
self._preprocess_obs,
{agent: obs_[agent] for agent in self.agents},
self.agents_one_hot,
)
obs = jax.tree_util.tree_map(
lambda d, o: jnp.where(d, 0.0, o),
{agent: done[agent] for agent in self.agents},
obs,
) # ensure that the obs are 0s for done agents
else:
obs = obs_
obs["__all__"] = self.global_state(obs_, state)
@@ -302,21 +433,25 @@ def wrapped_step(self, key, state, actions):
@partial(jax.jit, static_argnums=0)
def global_state(self, obs, state):
return jnp.concatenate([obs[agent] for agent in self.agents], axis=-1)

@partial(jax.jit, static_argnums=0)
def global_reward(self, reward):
return jnp.stack([reward[agent] for agent in self.training_agents]).sum(axis=0)

def batch_sample(self, key, agent):
return self.batch_samplers[agent](jax.random.split(key, self.batch_size)).astype(int)
return self.batch_samplers[agent](
jax.random.split(key, self.batch_size)
).astype(int)

@partial(jax.jit, static_argnums=0)
def _preprocess_obs(self, arr, extra_features):
# flatten
arr = arr.flatten()
# pad the observation vectors to the maximum length
pad_width = [(0, 0)] * (arr.ndim - 1) + [(0, max(0, self.max_obs_length - arr.shape[-1]))]
arr = jnp.pad(arr, pad_width, mode='constant', constant_values=0)
pad_width = [(0, 0)] * (arr.ndim - 1) + [
(0, max(0, self.max_obs_length - arr.shape[-1]))
]
arr = jnp.pad(arr, pad_width, mode="constant", constant_values=0)
# concatenate the extra features
arr = jnp.concatenate((arr, extra_features), axis=-1)
return arr
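
For context, a minimal usage sketch of the new OvercookedV2LogWrapper added in this commit. The reset/step signatures and the logged info keys follow the code above; the environment id "overcooked_v2" and the random-action loop are illustrative assumptions, not part of the commit.

import jax
from jaxmarl import make
from jaxmarl.wrappers.baselines import OvercookedV2LogWrapper

# Wrap an OvercookedV2 environment with the logging wrapper
# (the registry id "overcooked_v2" is an assumption for illustration).
env = OvercookedV2LogWrapper(make("overcooked_v2"), replace_info=False)

key = jax.random.PRNGKey(0)
key, reset_key = jax.random.split(key)
obs, state = env.reset(reset_key)

for _ in range(100):
    key, step_key, *action_keys = jax.random.split(key, 2 + env.num_agents)
    # Sample a random action per agent (illustrative only).
    actions = {
        agent: env.action_space(agent).sample(k)
        for agent, k in zip(env.agents, action_keys)
    }
    obs, state, reward, done, info = env.step(step_key, state, actions)

# Episode-level metrics are carried in `info`; per-recipe returns are keyed
# by the "<ing0>_<ing1>_<ing2>" names built from env.possible_recipes.
print(info["returned_episode_returns"])
print(info["returned_episode_recipe_returns"])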
