change observation/action space and add support for more than 2 players
tindiz committed Jun 19, 2024
1 parent fdf6595 commit 9dbc311
Showing 3 changed files with 227 additions and 102 deletions.
125 changes: 62 additions & 63 deletions jaxmarl/environments/hanabi/hanabi.py
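
For orientation, the reworked action space generalizes hints to any number of players: each agent has hand_size discard actions, hand_size play actions, one block of num_colors color hints and one block of num_ranks rank hints per other player, plus a trailing no-op ('N'). A rough size check under that assumed layout; the exact block ordering is inferred from the encoding loop below, not stated in the diff:

# Illustrative sketch only: layout and no-op action are assumptions.
hand_size, num_colors, num_ranks, num_agents = 5, 5, 5, 3  # example settings
num_moves = (
    2 * hand_size                                  # discard + play own cards
    + (num_agents - 1) * (num_colors + num_ranks)  # hints to each other player
    + 1                                            # no-op ('N')
)
# num_moves == 31 for these example settings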
@@ -62,30 +62,25 @@ def __init__(
).squeeze()

self.action_set = jnp.arange(self.num_moves)
self.action_encoding = {}
for i, a in enumerate(self.action_set):
if a < hand_size:
move_type = f'D{i%hand_size}'
elif a < hand_size*2:
move_type = f'P{i%hand_size}'
elif a < hand_size*2 + num_colors:
move_type = f'H{self.color_map[i - hand_size*2]}'
elif a < hand_size*2 + (num_colors + num_ranks):
move_type = f'H{i - (hand_size*2+num_colors)+1}'
else:
move_type = 'N'
self.action_encoding[int(a)] = move_type

# useful ranges to know the type of the action
self.discard_action_range = jnp.arange(0, self.hand_size)
self.play_action_range = jnp.arange(self.hand_size, 2 * self.hand_size)
self.color_action_range = jnp.arange(
2 * self.hand_size, 2 * self.hand_size + self.num_colors
)
self.rank_action_range = jnp.arange(
2 * self.hand_size + self.num_colors,
2 * self.hand_size + self.num_colors + self.num_ranks,
)
self.action_encodings = []
for p in range(self.num_agents):
action_encoding = {}
for i, a in enumerate(self.action_set):
if self._is_discard(a):
move_type = f'D{i % hand_size}'
elif self._is_play(a):
move_type = f'P{i % hand_size}'
elif self._is_hint_color(a):
target_player, hint_idx = self._get_target_player_and_hint_index(p, a)
color = self.color_map[hint_idx]
move_type = f'H{color} to P{target_player}'
elif self._is_hint_rank(a):
target_player, hint_idx = self._get_target_player_and_hint_index(p, a)
move_type = f'H{hint_idx + 1} to P{target_player}'
else:
move_type = 'N'
action_encoding[i] = move_type
self.action_encodings.append(action_encoding)

# number of features
self.hands_n_feats = (
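
The new encoding loop leans on helper predicates (_is_discard, _is_play, _is_hint_color, _is_hint_rank) and on _get_target_player_and_hint_index, none of which are visible in this hunk. A minimal standalone sketch of how they might work, assuming the action layout described above; every name and boundary here is an assumption, not code from the commit:

# Sketch under assumptions: action layout and target-player ordering are inferred.
import jax.numpy as jnp

def is_discard(action, hand_size):
    # first hand_size indices discard the corresponding card
    return action < hand_size

def is_play(action, hand_size):
    return (action >= hand_size) & (action < 2 * hand_size)

def is_hint_color(action, hand_size, num_colors, num_agents):
    start = 2 * hand_size
    return (action >= start) & (action < start + (num_agents - 1) * num_colors)

def is_hint_rank(action, hand_size, num_colors, num_ranks, num_agents):
    start = 2 * hand_size + (num_agents - 1) * num_colors
    return (action >= start) & (action < start + (num_agents - 1) * num_ranks)

def get_target_player_and_hint_index(aidx, action, hand_size, num_colors,
                                     num_ranks, num_agents):
    # offset of the action inside the color-hint or rank-hint block
    color_offset = action - 2 * hand_size
    rank_offset = color_offset - (num_agents - 1) * num_colors
    is_rank = is_hint_rank(action, hand_size, num_colors, num_ranks, num_agents)
    offset = jnp.where(is_rank, rank_offset, color_offset)
    block = jnp.where(is_rank, num_ranks, num_colors)
    relative_target = offset // block      # 0 = next player, 1 = the one after, ...
    hint_idx = offset % block              # which color / rank is hinted
    target_player = (aidx + 1 + relative_target) % num_agents
    return target_player, hint_idx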
@@ -136,14 +131,21 @@ def __init__(
def reset(self, key: chex.PRNGKey) -> Tuple[Dict, State]:
"""Reset the environment and return the initial observation."""
state = self.reset_game(key)
obs = self.get_obs(state, state, action=20)
obs = self.get_obs(state, state, action=self.num_moves - 1)
return obs, state

@partial(jax.jit, static_argnums=[0])
def reset_from_deck(self, key: chex.PRNGKey, deck:chex.Array) -> Tuple[Dict, State]:
def reset_from_deck(self, deck: chex.Array) -> Tuple[Dict, State]:
"""Inject a deck in the game. Useful for testing."""
state = self.reset_game_from_deck(key, deck)
obs = self.get_obs(state, state, action=20)
state = self.reset_game_from_deck(deck)
obs = self.get_obs(state, state, action=self.num_moves - 1)
return obs, state

@partial(jax.jit, static_argnums=[0])
def reset_from_deck_of_pairs(self, deck: chex.Array) -> Tuple[Dict, State]:
"""Inject a deck from (color, rank) pairs."""
state = self.reset_game_from_deck_of_pairs(deck)
obs = self.get_obs(state, state, action=self.num_moves - 1)
return obs, state

@partial(jax.jit, static_argnums=[0])
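
The deck-injection resets are intended for tests. A hypothetical usage sketch; the make() entry point, constructor arguments, and deck shape are assumptions for illustration only:

import jax.numpy as jnp
from jaxmarl import make  # assumed entry point for the registered env

env = make("hanabi", num_agents=3)  # hypothetical constructor arguments
# a deck supplied as (color, rank) index pairs, dealt in this order (shortened)
deck_pairs = jnp.array([[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]])
obs, state = env.reset_from_deck_of_pairs(deck_pairs)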
@@ -175,7 +177,7 @@ def step_env(

obs = lax.stop_gradient(self.get_obs(new_state, old_state, action))

return (obs, lax.stop_gradient(new_state), rewards, dones, info)
return obs, lax.stop_gradient(new_state), rewards, dones, info

@partial(jax.jit, static_argnums=[0])
def get_obs(
@@ -184,7 +186,7 @@ def get_obs(
"""Get all agents' observations."""

# no agent-specific obs
board_fats = self.get_board_fats(new_state)
board_fats = self.get_board_feats(new_state)
discard_feats = self._binarize_discard_pile(new_state.discard_pile)

def _observe(aidx: int):
Expand Down Expand Up @@ -242,7 +244,7 @@ def _legal_moves(aidx: int, state: State) -> chex.Array:
hands = state.player_hands
info_tokens = state.info_tokens
# discard is legal only when the info tokens are not at the maximum
is_not_max_info_tokens = jnp.sum(state.info_tokens) < 8
is_not_max_info_tokens = jnp.sum(state.info_tokens) < self.max_info_tokens
legal_moves = legal_moves.at[move_idx : move_idx + self.hand_size].set(
is_not_max_info_tokens
)
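
For intuition, discards are legal only while at least one information token has been spent; when every token is available the discard entries of the mask are zeroed. A tiny illustration, assuming info_tokens is a binary vector of length max_info_tokens and discards occupy the first hand_size slots:

import jax.numpy as jnp

max_info_tokens = 8
info_tokens = jnp.ones(max_info_tokens)               # every token still available
can_discard = jnp.sum(info_tokens) < max_info_tokens  # False here
# legal_moves[0:hand_size] (the discard slots) are set to can_discard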
@@ -305,60 +307,58 @@ def get_last_action_feats_(
):
"""Get the features of the last action taken"""

acting_player_index = old_state.cur_player_idx # absolute OH index
target_player_index = new_state.cur_player_idx # absolute OH index
acting_player_index = old_state.cur_player_idx
acting_player_relative_index = jnp.roll(
acting_player_index, -aidx
) # relative OH index

target_player, hint_idx = self._get_target_player_and_hint_index(aidx, action)
target_player_relative_index = jnp.roll(
target_player_index, -aidx
jax.nn.one_hot(target_player, num_classes=self.num_agents), -aidx
) # relative OH index

# in obl the encoding order here is: play, discard, reveal_c, reveal_r
move_type = jnp.where( # hard encoded but hey ho let's go
(action >= 0) & (action < 5), # discard
jnp.array([0, 1, 0, 0]),
move_type = jnp.where(
self._is_play(action),
jnp.array([1, 0, 0, 0]),
jnp.where(
(action >= 5) & (action < 10), # play
jnp.array([1, 0, 0, 0]),
self._is_discard(action),
jnp.array([0, 1, 0, 0]),
jnp.where(
(action >= 10) & (action < 15), # reveal_c
self._is_hint_color(action),
jnp.array([0, 0, 1, 0]),
jnp.where(
(action >= 15) & (action < 20), # reveal_r
self._is_hint_rank(action),
jnp.array([0, 0, 0, 1]),
jnp.array([0, 0, 0, 0]), # invalid
jnp.array([0, 0, 0, 0]),
),
),
),
)
)

target_player_relative_index_feat = jnp.where(
action >= 10, # only for hint actions
self._is_hint(action), # only for hint actions
target_player_relative_index,
jnp.zeros(self.num_agents),
)

# get the hand of the target player
target_hand = new_state.player_hands[
jnp.nonzero(target_player_index, size=1)[0][0]
]
target_hand = new_state.player_hands[target_player]

color_revealed = jnp.where( # which color was revealed by action (oh)?
action == self.color_action_range,
1.0,
jnp.zeros(self.color_action_range.size),
color_revealed = jnp.where(
self._is_hint_color(action),
jax.nn.one_hot(hint_idx, num_classes=self.num_colors),
jnp.zeros(self.num_colors)
)

rank_revealed = jnp.where( # which rank was revealed by action (oh)?
action == self.rank_action_range,
1.0,
jnp.zeros(self.rank_action_range.size),
rank_revealed = jnp.where(
self._is_hint_rank(action),
jax.nn.one_hot(hint_idx, num_classes=self.num_ranks),
jnp.zeros(self.num_ranks)
)

# cards that have the color that was revealed
color_revealed_cards = jnp.where(
(target_hand.sum(axis=(2)) == color_revealed).all(
(target_hand.sum(axis=2) == color_revealed).all(
axis=1
), # color of the card == color revealed
1,
@@ -367,7 +367,7 @@

# cards that have the rank that was revealed
rank_revealed_cards = jnp.where(
(target_hand.sum(axis=(1)) == rank_revealed).all(
(target_hand.sum(axis=1) == rank_revealed).all(
axis=1
), # rank of the card == rank revealed
1,
@@ -379,8 +379,8 @@

# position of the card that was played or discarded
pos_played_discarded = jnp.where(
action < 2 * self.hand_size,
jnp.arange(self.hand_size) == action % self.hand_size,
jnp.logical_or(self._is_play(action), self._is_discard(action)),
jnp.arange(self.hand_size) == (action % self.hand_size),
jnp.zeros(self.hand_size),
)
actor_hand_before = old_state.player_hands[
@@ -402,7 +402,7 @@

# "added info token" boolean is present only when you get an info from playing the 5 of the color
added_info_tokens = jnp.where(
(action >= 5) & (action < 10),
self._is_play(action),
new_state.info_tokens.sum() > old_state.info_tokens.sum(),
0,
)
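
The added-info-token feature fires only for play actions, when completing a firework (playing a 5) returns a token. A tiny illustrative check with made-up token counts:

import jax.numpy as jnp

old_info, new_info = 6, 7   # example: one token regained by playing a 5
is_play = True
added_info_tokens = jnp.where(is_play, new_info > old_info, 0)  # evaluates to 1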
Expand Down Expand Up @@ -443,9 +443,8 @@ def get_last_action_feats(

return last_action


@partial(jax.jit, static_argnums=[0])
def get_board_fats(self, state: State):
def get_board_feats(self, state: State):
"""Get the features of the board."""
# by default the fireworks are cumulative, i.e. [1,1,0,0,0] means ranks 1 and 2 have been played
# must be a one-hot of only the highest rank, i.e. [0,1,0,0,0]
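
get_board_feats (renamed from the get_board_fats typo) therefore converts the cumulative fireworks row into a one-hot of the highest rank played. A minimal sketch of that conversion for a single color; the actual implementation is not shown in this hunk:

import jax.numpy as jnp

fireworks = jnp.array([1., 1., 0., 0., 0.])   # cumulative: ranks 1 and 2 played
highest_only = fireworks - jnp.concatenate([fireworks[1:], jnp.zeros(1)])
# highest_only == [0., 1., 0., 0., 0.]  (one-hot of the highest rank)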