diff --git a/jaxmarl/wrappers/baselines.py b/jaxmarl/wrappers/baselines.py
index 4b476ce8a..1008bdc79 100644
--- a/jaxmarl/wrappers/baselines.py
+++ b/jaxmarl/wrappers/baselines.py
@@ -1,4 +1,5 @@
 """ Wrappers for use with jaxmarl baselines. """
+
 import os
 import jax
 import jax.numpy as jnp
@@ -10,17 +11,20 @@
 # from gymnax.environments import environment, spaces
 from gymnax.environments.spaces import Box as BoxGymnax, Discrete as DiscreteGymnax
 from typing import Dict, Optional, List, Tuple, Union
+from jaxmarl.environments.overcooked_v2.common import DynamicObject
 from jaxmarl.environments.spaces import Box, Discrete, MultiDiscrete
 from jaxmarl.environments.multi_agent_env import MultiAgentEnv, State
 from safetensors.flax import save_file, load_file
 from flax.traverse_util import flatten_dict, unflatten_dict
 
+
 def save_params(params: Dict, filename: Union[str, os.PathLike]) -> None:
-    flattened_dict = flatten_dict(params, sep=',')
+    flattened_dict = flatten_dict(params, sep=",")
     save_file(flattened_dict, filename)
 
-def load_params(filename:Union[str, os.PathLike]) -> Dict:
+
+def load_params(filename: Union[str, os.PathLike]) -> Dict:
     flattened_dict = load_file(filename)
     return unflatten_dict(flattened_dict, sep=",")
 
@@ -101,10 +105,97 @@ def step(
         info["returned_episode"] = jnp.full((self._env.num_agents,), ep_done)
         return obs, state, reward, done, info
 
+
+@struct.dataclass
+class OvercookedV2LogEnvState:
+    env_state: State
+    episode_returns: float
+    episode_lengths: int
+    returned_episode_returns: float
+    returned_episode_lengths: int
+    returned_episode_recipe_returns: Dict[str, float]
+
+
+class OvercookedV2LogWrapper(JaxMARLWrapper):
+    def __init__(self, env: MultiAgentEnv, replace_info: bool = False):
+        super().__init__(env)
+        self.replace_info = replace_info
+
+        self.recipe_dict = {
+            f"{recipe[0]}_{recipe[1]}_{recipe[2]}": DynamicObject.get_recipe_encoding(
+                recipe
+            )
+            for recipe in env.possible_recipes
+        }
+
+    @partial(jax.jit, static_argnums=(0,))
+    def reset(self, key: chex.PRNGKey) -> Tuple[chex.Array, State]:
+        obs, env_state = self._env.reset(key)
+
+        recipe_returns = {
+            r: jnp.zeros((self._env.num_agents,)) for r in self.recipe_dict
+        }
+
+        state = OvercookedV2LogEnvState(
+            env_state,
+            jnp.zeros((self._env.num_agents,)),
+            jnp.zeros((self._env.num_agents,)),
+            jnp.zeros((self._env.num_agents,)),
+            jnp.zeros((self._env.num_agents,)),
+            recipe_returns,
+        )
+        return obs, state
+
+    @partial(jax.jit, static_argnums=(0,))
+    def step(
+        self,
+        key: chex.PRNGKey,
+        state: OvercookedV2LogEnvState,
+        action: Union[int, float],
+    ) -> Tuple[chex.Array, LogEnvState, float, bool, dict]:
+        obs, env_state, reward, done, info = self._env.step(
+            key, state.env_state, action
+        )
+        ep_done = done["__all__"]
+        batch_reward = self._batchify_floats(reward)
+        new_episode_return = state.episode_returns + self._batchify_floats(reward)
+        new_episode_length = state.episode_lengths + 1
+        new_won_episode = (batch_reward >= 1.0).astype(jnp.float32)
+
+        updated_recipe_returns = {
+            id: jax.lax.select(
+                (state.env_state.recipe == self.recipe_dict[id]) & ep_done,
+                new_episode_return,
+                old_episode_return,
+            )
+            for id, old_episode_return in state.returned_episode_recipe_returns.items()
+        }
+
+        state = OvercookedV2LogEnvState(
+            env_state=env_state,
+            episode_returns=new_episode_return * (1 - ep_done),
+            episode_lengths=new_episode_length * (1 - ep_done),
+            returned_episode_returns=jax.lax.select(
+                ep_done, new_episode_return, state.returned_episode_returns
+            ),
+            returned_episode_lengths=jax.lax.select(
+                ep_done, new_episode_length, state.returned_episode_lengths
+            ),
+            returned_episode_recipe_returns=updated_recipe_returns,
+        )
+        if self.replace_info:
+            info = {}
+        info["returned_episode_returns"] = state.returned_episode_returns
+        info["returned_episode_lengths"] = state.returned_episode_lengths
+        info["returned_episode"] = jnp.full((self._env.num_agents,), ep_done)
+        info["returned_episode_recipe_returns"] = state.returned_episode_recipe_returns
+        return obs, state, reward, done, info
+
+
 class MPELogWrapper(LogWrapper):
-    """ Times reward signal by number of agents within the environment,
-    to match the on-policy codebase. """
-    
+    """Times reward signal by number of agents within the environment,
+    to match the on-policy codebase."""
+
     @partial(jax.jit, static_argnums=(0,))
     def step(
         self,
@@ -115,7 +206,9 @@ def step(
         obs, env_state, reward, done, info = self._env.step(
             key, state.env_state, action
         )
-        rewardlog = jax.tree_map(lambda x: x*self._env.num_agents, reward) # As per on-policy codebase
+        rewardlog = jax.tree_map(
+            lambda x: x * self._env.num_agents, reward
+        )  # As per on-policy codebase
         ep_done = done["__all__"]
         new_episode_return = state.episode_returns + self._batchify_floats(rewardlog)
         new_episode_length = state.episode_lengths + 1
@@ -135,6 +228,7 @@ def step(
         info["returned_episode"] = jnp.full((self._env.num_agents,), ep_done)
         return obs, state, reward, done, info
 
+
 @struct.dataclass
 class SMAXLogEnvState:
     env_state: State
@@ -209,7 +303,9 @@ def get_space_dim(space):
         return np.prod(space.shape)
     else:
         print(space)
-        raise NotImplementedError('Current wrapper works only with Discrete/MultiDiscrete/Box action and obs spaces')
+        raise NotImplementedError(
+            "Current wrapper works only with Discrete/MultiDiscrete/Box action and obs spaces"
+        )
 
 
 class CTRolloutManager(JaxMARLWrapper):
@@ -225,46 +321,69 @@ class CTRolloutManager(JaxMARLWrapper):
     - global_reward is the sum of all agents' rewards.
     """
 
-    def __init__(self, env: MultiAgentEnv, batch_size:int, training_agents:List=None, preprocess_obs:bool=True):
-
+    def __init__(
+        self,
+        env: MultiAgentEnv,
+        batch_size: int,
+        training_agents: List = None,
+        preprocess_obs: bool = True,
+    ):
+
         super().__init__(env)
-
+
         self.batch_size = batch_size
 
         # the agents to train could differ from the total trainable agents in the env (f.i. if using pretrained agents)
         # it's important to know it in order to compute properly the default global rewards and state
-        self.training_agents = self.agents if training_agents is None else training_agents
-        self.preprocess_obs = preprocess_obs
+        self.training_agents = (
+            self.agents if training_agents is None else training_agents
+        )
+        self.preprocess_obs = preprocess_obs
 
         # TOREMOVE: this is because overcooked doesn't follow other envs conventions
         if len(env.observation_spaces) == 0:
-            self.observation_spaces = {agent:self.observation_space() for agent in self.agents}
+            self.observation_spaces = {
+                agent: self.observation_space() for agent in self.agents
+            }
         if len(env.action_spaces) == 0:
-            self.action_spaces = {agent:env.action_space() for agent in self.agents}
-
+            self.action_spaces = {agent: env.action_space() for agent in self.agents}
+
         # batched action sampling
-        self.batch_samplers = {agent: jax.jit(jax.vmap(self.action_space(agent).sample, in_axes=0)) for agent in self.agents}
+        self.batch_samplers = {
+            agent: jax.jit(jax.vmap(self.action_space(agent).sample, in_axes=0))
+            for agent in self.agents
+        }
 
         # assumes the observations are flattened vectors
-        self.max_obs_length = max(list(map(lambda x: get_space_dim(x), self.observation_spaces.values())))
-        self.max_action_space = max(list(map(lambda x: get_space_dim(x), self.action_spaces.values())))
+        self.max_obs_length = max(
+            list(map(lambda x: get_space_dim(x), self.observation_spaces.values()))
+        )
+        self.max_action_space = max(
+            list(map(lambda x: get_space_dim(x), self.action_spaces.values()))
+        )
         self.obs_size = self.max_obs_length + len(self.agents)
 
         # agents ids
-        self.agents_one_hot = {a:oh for a, oh in zip(self.agents, jnp.eye(len(self.agents)))}
+        self.agents_one_hot = {
+            a: oh for a, oh in zip(self.agents, jnp.eye(len(self.agents)))
+        }
         # valid actions
-        self.valid_actions = {a:jnp.arange(u.n) for a, u in self.action_spaces.items()}
-        self.valid_actions_oh ={a:jnp.concatenate((jnp.ones(u.n), jnp.zeros(self.max_action_space - u.n))) for a, u in self.action_spaces.items()}
+        self.valid_actions = {a: jnp.arange(u.n) for a, u in self.action_spaces.items()}
+        self.valid_actions_oh = {
+            a: jnp.concatenate((jnp.ones(u.n), jnp.zeros(self.max_action_space - u.n)))
+            for a, u in self.action_spaces.items()
+        }
 
         # custom global state and rewards for specific envs
-        if 'smax' in env.name.lower():
-            self.global_state = lambda obs, state: obs['world_state']
-            self.global_reward = lambda rewards: rewards[self.training_agents[0]]*10
-        elif 'overcooked' in env.name.lower():
-            self.global_state = lambda obs, state: jnp.concatenate([obs[agent].ravel() for agent in self.agents], axis=-1)
+        if "smax" in env.name.lower():
+            self.global_state = lambda obs, state: obs["world_state"]
+            self.global_reward = lambda rewards: rewards[self.training_agents[0]] * 10
+        elif "overcooked" in env.name.lower():
+            self.global_state = lambda obs, state: jnp.concatenate(
+                [obs[agent].ravel() for agent in self.agents], axis=-1
+            )
             self.global_reward = lambda rewards: rewards[self.training_agents[0]]
 
-
     @partial(jax.jit, static_argnums=0)
     def batch_reset(self, key):
        keys = jax.random.split(key, self.batch_size)
@@ -279,7 +398,11 @@ def batch_step(self, key, states, actions):
     def wrapped_reset(self, key):
         obs_, state = self._env.reset(key)
         if self.preprocess_obs:
-            obs = jax.tree_util.tree_map(self._preprocess_obs, {agent:obs_[agent] for agent in self.agents}, self.agents_one_hot)
+            obs = jax.tree_util.tree_map(
+                self._preprocess_obs,
+                {agent: obs_[agent] for agent in self.agents},
+                self.agents_one_hot,
+            )
         else:
             obs = obs_
         obs["__all__"] = self.global_state(obs_, state)
@@ -287,12 +410,20 @@ def wrapped_reset(self, key):
 
     @partial(jax.jit, static_argnums=0)
     def wrapped_step(self, key, state, actions):
-        if 'hanabi' in self._env.name.lower():
-            actions = jax.tree_util.tree_map(lambda x:jnp.expand_dims(x, 0), actions)
+        if "hanabi" in self._env.name.lower():
+            actions = jax.tree_util.tree_map(lambda x: jnp.expand_dims(x, 0), actions)
         obs_, state, reward, done, infos = self._env.step(key, state, actions)
         if self.preprocess_obs:
-            obs = jax.tree_util.tree_map(self._preprocess_obs, {agent:obs_[agent] for agent in self.agents}, self.agents_one_hot)
-            obs = jax.tree_util.tree_map(lambda d, o: jnp.where(d, 0., o), {agent:done[agent] for agent in self.agents}, obs) # ensure that the obs are 0s for done agents
+            obs = jax.tree_util.tree_map(
+                self._preprocess_obs,
+                {agent: obs_[agent] for agent in self.agents},
+                self.agents_one_hot,
+            )
+            obs = jax.tree_util.tree_map(
+                lambda d, o: jnp.where(d, 0.0, o),
+                {agent: done[agent] for agent in self.agents},
+                obs,
+            )  # ensure that the obs are 0s for done agents
         else:
             obs = obs_
         obs["__all__"] = self.global_state(obs_, state)
@@ -302,21 +433,25 @@ def wrapped_step(self, key, state, actions):
     @partial(jax.jit, static_argnums=0)
     def global_state(self, obs, state):
         return jnp.concatenate([obs[agent] for agent in self.agents], axis=-1)
-
+
     @partial(jax.jit, static_argnums=0)
     def global_reward(self, reward):
-        return jnp.stack([reward[agent] for agent in self.training_agents]).sum(axis=0)
-
+        return jnp.stack([reward[agent] for agent in self.training_agents]).sum(axis=0)
+
     def batch_sample(self, key, agent):
-        return self.batch_samplers[agent](jax.random.split(key, self.batch_size)).astype(int)
+        return self.batch_samplers[agent](
+            jax.random.split(key, self.batch_size)
+        ).astype(int)
 
     @partial(jax.jit, static_argnums=0)
     def _preprocess_obs(self, arr, extra_features):
         # flatten
         arr = arr.flatten()
         # pad the observation vectors to the maximum length
-        pad_width = [(0, 0)] * (arr.ndim - 1) + [(0, max(0, self.max_obs_length - arr.shape[-1]))]
-        arr = jnp.pad(arr, pad_width, mode='constant', constant_values=0)
+        pad_width = [(0, 0)] * (arr.ndim - 1) + [
+            (0, max(0, self.max_obs_length - arr.shape[-1]))
+        ]
+        arr = jnp.pad(arr, pad_width, mode="constant", constant_values=0)
         # concatenate the extra features
         arr = jnp.concatenate((arr, extra_features), axis=-1)
         return arr
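
A minimal usage sketch for the new OvercookedV2LogWrapper follows. It assumes Overcooked V2 is registered with jaxmarl.make under the id "overcooked_v2" and exposes the standard per-agent action_space(agent) interface; the wrapper's reset/step signatures and the "returned_episode_recipe_returns" info key come from the diff above.

# Sketch only: the "overcooked_v2" env id is an assumption about how the
# environment is registered; substitute your own constructor if it differs.
import jax
from jaxmarl import make
from jaxmarl.wrappers.baselines import OvercookedV2LogWrapper

env = OvercookedV2LogWrapper(make("overcooked_v2"), replace_info=False)

rng = jax.random.PRNGKey(0)
rng, key_reset, key_act, key_step = jax.random.split(rng, 4)

obs, state = env.reset(key_reset)

# sample one random joint action, one PRNG key per agent
act_keys = jax.random.split(key_act, env.num_agents)
actions = {
    agent: env.action_space(agent).sample(k)
    for agent, k in zip(env.agents, act_keys)
}

obs, state, reward, done, info = env.step(key_step, state, actions)

# per-recipe episode returns are populated once an episode terminates
print(info["returned_episode_recipe_returns"])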
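
A similar round-trip sketch for the save_params/load_params helpers reformatted above; the nested params pytree and the checkpoint filename are placeholders.

import jax.numpy as jnp
from jaxmarl.wrappers.baselines import save_params, load_params

# placeholder parameter pytree (nested dict of arrays, e.g. TrainState.params)
params = {"actor": {"dense": {"kernel": jnp.ones((4, 2)), "bias": jnp.zeros((2,))}}}

save_params(params, "checkpoint.safetensors")  # flattened with sep="," and saved via safetensors
restored = load_params("checkpoint.safetensors")
assert jnp.allclose(restored["actor"]["dense"]["kernel"], params["actor"]["dense"]["kernel"])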