Skip to content

Commit 29fbca7

Browse files
authored
Merge pull request #106 from FLAIROx/bugfix/hanabi-legal-actions
Hanabi Legal Actions Bugfix
2 parents 2df4446 + 1e8587c commit 29fbca7

File tree

2 files changed

+48
-54
lines changed

2 files changed

+48
-54
lines changed

jaxmarl/environments/hanabi/hanabi.py

Lines changed: 47 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -241,63 +241,57 @@ def get_legal_moves(self, state: State) -> chex.Array:
241241

242242
@partial(jax.vmap, in_axes=[0, None])
243243
def _legal_moves(aidx: int, state: State) -> chex.Array:
244-
"""
245-
Legal moves encoding in order:
246-
- discard for all cards in hand
247-
- play for all cards in hand
248-
- hint for all colors and ranks for all other players
249-
"""
250-
move_idx = 0
244+
# all moves are illegal in the beginning
251245
legal_moves = jnp.zeros(self.num_moves)
252-
hands = state.player_hands
253-
info_tokens = state.info_tokens
254-
# discard legal when discard tokens are not full
255-
is_not_max_info_tokens = jnp.sum(state.info_tokens) < self.max_info_tokens
256-
legal_moves = legal_moves.at[move_idx : move_idx + self.hand_size].set(
257-
is_not_max_info_tokens
258-
)
259-
move_idx += self.hand_size
260-
# play moves always legal
261-
legal_moves = legal_moves.at[move_idx : move_idx + self.hand_size].set(1)
262-
move_idx += self.hand_size
263-
# hints depend on other player cards
264-
other_hands = jnp.delete(hands, aidx, axis=0, assume_unique_indices=True)
265-
266-
def _get_hints_for_hand(carry, unused):
267-
"""Generates valid hints given hand"""
268-
aidx, other_hands = carry
269-
hand = other_hands[aidx]
270-
271-
# get occurrences of each card
272-
card_counts = jnp.sum(hand, axis=0)
273-
# get occurrences of each color
274-
color_counts = jnp.sum(card_counts, axis=1)
275-
# get occurrences of each rank
276-
rank_counts = jnp.sum(card_counts, axis=0)
277-
# check which colors/ranks in hand
278-
colors_present = jnp.where(color_counts > 0, 1, 0)
279-
ranks_present = jnp.where(rank_counts > 0, 1, 0)
280-
281-
valid_hints = jnp.concatenate([colors_present, ranks_present])
282-
carry = (aidx + 1, other_hands)
283-
284-
return carry, valid_hints
285-
286-
_, valid_hints = lax.scan(
287-
_get_hints_for_hand, (0, other_hands), None, self.num_agents - 1
246+
all_player_hands = state.player_hands # (num_players, hand_size, num_colors, num_ranks)
247+
248+
# first get all cards one is holding since only these are legally playable
249+
my_hand = all_player_hands.at[aidx].get() # (hand_size, num_colors, num_ranks)
250+
holding_cards_idx = jax.vmap(lambda c: jnp.any(c))(my_hand)
251+
252+
# discard is legal when tokens are not full and when card is in hand
253+
can_get_token = jnp.sum(state.info_tokens) < self.max_info_tokens
254+
legal_discard_idx = holding_cards_idx * can_get_token
255+
legal_moves = legal_moves.at[self.discard_action_range].set(legal_discard_idx)
256+
257+
# play is legal for cards in hand - if empty card, not legal.
258+
legal_moves = legal_moves.at[self.play_action_range].set(holding_cards_idx)
259+
260+
# hints depend on cards held by other players and not our own
261+
other_players_hands = jnp.delete(
262+
all_player_hands, aidx, axis=0, assume_unique_indices=True
288263
)
289-
# make other player positions relative to current player
290-
valid_hints = jnp.roll(valid_hints, -aidx, axis=0)
291-
# include valid hints in legal moves
292-
num_hints = (self.num_agents - 1) * (self.num_colors + self.num_ranks)
293-
valid_hints = jnp.concatenate(valid_hints, axis=0)
294-
info_tokens_available = jnp.sum(info_tokens) != 0
295-
valid_hints *= info_tokens_available
296-
legal_moves = legal_moves.at[move_idx : move_idx + num_hints].set(
297-
valid_hints
264+
# adjust to have relative positions
265+
other_players_hands = jnp.roll(
266+
other_players_hands, -aidx, axis=0
298267
)
299268

300-
# only enable noop if not current player
269+
# cards can be hinted only if info tokens are available
270+
info_tokens_available = jnp.sum(state.info_tokens) > 0
271+
272+
# get all the colors that can be hinted
273+
def _hintable_colors(hand):
274+
# Hand: (num_cards, num_colors, num_ranks)
275+
card_colors = jnp.sum(hand, axis=2)
276+
hintable_colors = card_colors.any(axis=0)
277+
return hintable_colors
278+
279+
legal_color_hints = jax.vmap(_hintable_colors)(other_players_hands).ravel()
280+
legal_color_hints = legal_color_hints * info_tokens_available
281+
legal_moves = legal_moves.at[self.color_action_range].set(legal_color_hints)
282+
283+
# get all the ranks that can be hinted.
284+
def _hintable_ranks(hand):
285+
# Hand: (num_cards, num_colors, num_ranks)
286+
card_ranks = jnp.sum(hand, axis=1)
287+
hintable_ranks = card_ranks.any(axis=0)
288+
return hintable_ranks
289+
290+
legal_rank_hints = jax.vmap(_hintable_ranks)(other_players_hands).ravel()
291+
legal_rank_hints = legal_rank_hints * info_tokens_available
292+
legal_moves = legal_moves.at[self.rank_action_range].set(legal_rank_hints)
293+
294+
# Only legalize noop if not current player.
301295
cur_player = jnp.nonzero(state.cur_player_idx, size=1)[0][0]
302296
not_cur_player = aidx != cur_player
303297
legal_moves -= legal_moves * not_cur_player

jaxmarl/environments/hanabi/hanabi_game.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def __init__(
5959
self.max_info_tokens = max_info_tokens
6060
self.max_life_tokens = max_life_tokens
6161
self.num_cards_of_rank = num_cards_of_rank
62-
self.deck_size = jnp.sum(num_cards_of_rank) * num_colors
62+
self.deck_size = np.sum(num_cards_of_rank) * num_colors
6363
self.color_map = color_map
6464

6565
# action ranges - useful to know

0 commit comments

Comments
 (0)