From ce57ce015ca38749dda1c93d429e65c7a89a8baf Mon Sep 17 00:00:00 2001
From: Stefan Roesch
Date: Thu, 25 Jul 2024 14:29:47 +0100
Subject: [PATCH 1/2] made coin game compatible with iql_rnn

---
 jaxmarl/environments/coin_game/coin_game.py | 45 +++++++++++++--------
 1 file changed, 29 insertions(+), 16 deletions(-)

diff --git a/jaxmarl/environments/coin_game/coin_game.py b/jaxmarl/environments/coin_game/coin_game.py
index 38f21895..6b945cf8 100644
--- a/jaxmarl/environments/coin_game/coin_game.py
+++ b/jaxmarl/environments/coin_game/coin_game.py
@@ -59,13 +59,13 @@ def __init__(
         self,
         num_inner_steps: int = 10,
         num_outer_steps: int = 10,
-        cnn: bool = True,
+        cnn: bool = False,
         egocentric: bool = False,
         payoff_matrix=[[1, 1, -2], [1, 1, -2]],
     ):
 
         super().__init__(num_agents=2)
-        self.agents = list(range(2))
+        self.agents = [str(i) for i in list(range(2))]
         self.payoff_matrix = payoff_matrix
 
         # helper functions
@@ -133,7 +133,8 @@ def _abs_position(state: EnvState) -> jnp.ndarray:
                 [obs1[:, :, 1], obs1[:, :, 0], obs1[:, :, 3], obs1[:, :, 2]],
                 axis=-1,
             )
-            return obs1, obs2
+            obs = {self.agents[0]: obs1, self.agents[1]: obs2}
+            return obs
 
         def _relative_position(state: EnvState) -> jnp.ndarray:
             """Assume canonical agent is red player"""
@@ -188,19 +189,21 @@ def _state_to_obs(state: EnvState) -> jnp.ndarray:
                         coop2=state.coop2,
                     )
                 )
+                obs = (obs1, obs2)
+                obs = {agent: obs for agent, obs in zip(self.agents, obs)}
             else:
-                obs1, obs2 = _abs_position(state)
+                obs = _abs_position(state)
             if not cnn:
-                return obs1.flatten(), obs2.flatten()
+                return {agent: obs[agent].flatten() for agent in obs}
 
-            return obs1, obs2
+            return obs
 
         def _step(
             key: chex.PRNGKey,
             state: EnvState,
             actions: Tuple[int, int],
         ):
-            action_0, action_1 = actions
+            action_0, action_1 = list(actions.values())
             new_red_pos = (state.red_pos + MOVES[action_0]) % 3
             new_blue_pos = (state.blue_pos + MOVES[action_1]) % 3
             red_reward, blue_reward = 0, 0
@@ -300,7 +303,7 @@ def _step(
                 last_state=last_state,
             )
 
-            obs1, obs2 = _state_to_obs(next_state)
+            obs = _state_to_obs(next_state)
 
             # now calculate if done for inner or outer episode
             inner_t = next_state.inner_t
@@ -340,17 +343,27 @@ def _step(
                 last_state=jnp.where(reset_inner, jnp.zeros(2), last_state),
             )
 
-            obs1 = jnp.where(reset_inner, reset_obs[0], obs1)
-            obs2 = jnp.where(reset_inner, reset_obs[1], obs2)
+            obs = {agent: obs for agent, obs in zip(self.agents, [jnp.where(reset_inner, reset_obs[i], obs[i]) for i in obs])}
 
             blue_reward = jnp.where(reset_inner, 0.0, blue_reward)
             red_reward = jnp.where(reset_inner, 0.0, red_reward)
+
+            # shared reward (social welfare/sum of agents individual rewards)
+            #rewards = {agent: reward for agent, reward in zip(self.agents, (sum((red_reward, blue_reward)), sum((red_reward, blue_reward))))}
+
+            # individual reward
+            rewards = {agent: reward for agent, reward in zip(self.agents, (red_reward, blue_reward))}
+
+            dones = {agent: reset_inner for agent in self.agents}
+            dones['__all__'] = reset_inner
+
+            infos = {}
 
             return (
-                (obs1, obs2),
+                obs,
                 next_state,
-                (red_reward, blue_reward),
-                reset_inner,
-                {"discount": jnp.zeros((), dtype=jnp.int8)},
+                rewards,
+                dones,
+                infos,
             )
 
         def _reset(
@@ -380,8 +393,8 @@ def _reset(
                 coop2=state_stats,
                 last_state=jnp.zeros(2),
             )
-            obs1, obs2 = _state_to_obs(state)
-            return (obs1, obs2), state
+            obs = _state_to_obs(state)
+            return obs, state
 
         # overwrite Gymnax as it makes single-agent assumptions
         self.step = jax.jit(_step)
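As context for the change above: the patch moves the environment from tuple-shaped observations, rewards, and dones to dictionaries keyed by the string agent ids "0" and "1", which is the interface per-agent trainers such as IQL-RNN expect. A minimal rollout sketch of that dict-based interface follows; it is illustrative only and assumes the environment class is named CoinGame, that reset takes just a PRNG key, and that there are four movement actions, none of which is asserted by the patch itself.

    import jax

    from jaxmarl.environments.coin_game.coin_game import CoinGame  # assumed import path and class name

    env = CoinGame()  # cnn now defaults to False; env.agents is ["0", "1"]

    key = jax.random.PRNGKey(0)
    key, reset_key = jax.random.split(key)
    obs, state = env.reset(reset_key)  # obs is a dict: {"0": ..., "1": ...}

    for _ in range(5):
        key, step_key, act_key = jax.random.split(key, 3)
        agent_keys = jax.random.split(act_key, len(env.agents))
        # One discrete action per agent, keyed by the string agent ids; the red
        # player is agent "0", the blue player agent "1".
        actions = {
            agent: jax.random.randint(k, (), 0, 4)  # assuming 4 movement actions
            for agent, k in zip(env.agents, agent_keys)
        }
        # step now returns dict-valued obs/rewards/dones plus an (empty) infos dict.
        obs, state, rewards, dones, infos = env.step(step_key, state, actions)
        # dones carries per-agent flags plus the aggregate "__all__" key.
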
From 8f3f033a0dfcc12faa5d91d9c6ef7fc9b895d6fd Mon Sep 17 00:00:00 2001
From: Stefan Nicolaas Roesch
Date: Wed, 4 Sep 2024 14:23:58 +0000
Subject: [PATCH 2/2] added flag for shared reward

---
 jaxmarl/environments/coin_game/coin_game.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/jaxmarl/environments/coin_game/coin_game.py b/jaxmarl/environments/coin_game/coin_game.py
index 6b945cf8..868ec1d3 100644
--- a/jaxmarl/environments/coin_game/coin_game.py
+++ b/jaxmarl/environments/coin_game/coin_game.py
@@ -61,6 +61,7 @@ def __init__(
         num_outer_steps: int = 10,
         cnn: bool = False,
         egocentric: bool = False,
+        shared_rewards: bool = False,
         payoff_matrix=[[1, 1, -2], [1, 1, -2]],
     ):
 
@@ -348,11 +349,12 @@ def _step(
             blue_reward = jnp.where(reset_inner, 0.0, blue_reward)
             red_reward = jnp.where(reset_inner, 0.0, red_reward)
 
-            # shared reward (social welfare/sum of agents individual rewards)
-            #rewards = {agent: reward for agent, reward in zip(self.agents, (sum((red_reward, blue_reward)), sum((red_reward, blue_reward))))}
-
-            # individual reward
-            rewards = {agent: reward for agent, reward in zip(self.agents, (red_reward, blue_reward))}
+            if shared_rewards:
+                # shared reward (social welfare/sum of agents individual rewards)
+                rewards = {agent: reward for agent, reward in zip(self.agents, (sum((red_reward, blue_reward)), sum((red_reward, blue_reward))))}
+            else:
+                # individual reward
+                rewards = {agent: reward for agent, reward in zip(self.agents, (red_reward, blue_reward))}
 
             dones = {agent: reset_inner for agent in self.agents}
             dones['__all__'] = reset_inner
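The second patch makes the reward mode switchable at construction time instead of by editing the commented-out lines from the first patch. A short sketch of the two modes, again assuming the class is named CoinGame as above:

    from jaxmarl.environments.coin_game.coin_game import CoinGame  # assumed import path and class name

    # Default: each agent keeps its individual reward
    # (red_reward for agent "0", blue_reward for agent "1").
    env_individual = CoinGame(shared_rewards=False)

    # Shared: both agents receive the social-welfare reward
    # red_reward + blue_reward, so rewards["0"] == rewards["1"] at every step.
    env_shared = CoinGame(shared_rewards=True)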