Commit 9dbc311
change observation/action space and add support for more than 2 players
1 parent fdf6595 commit 9dbc311

File tree

3 files changed: +227 −102 lines changed

jaxmarl/environments/hanabi/hanabi.py

Lines changed: 62 additions & 63 deletions
@@ -62,30 +62,25 @@ def __init__(
         ).squeeze()

         self.action_set = jnp.arange(self.num_moves)
-        self.action_encoding = {}
-        for i, a in enumerate(self.action_set):
-            if a < hand_size:
-                move_type = f'D{i%hand_size}'
-            elif a < hand_size*2:
-                move_type = f'P{i%hand_size}'
-            elif a < hand_size*2 + num_colors:
-                move_type = f'H{self.color_map[i - hand_size*2]}'
-            elif a < hand_size*2 + (num_colors + num_ranks):
-                move_type = f'H{i - (hand_size*2+num_colors)+1}'
-            else:
-                move_type = 'N'
-            self.action_encoding[int(a)] = move_type
-
-        # useful ranges to know the type of the action
-        self.discard_action_range = jnp.arange(0, self.hand_size)
-        self.play_action_range = jnp.arange(self.hand_size, 2 * self.hand_size)
-        self.color_action_range = jnp.arange(
-            2 * self.hand_size, 2 * self.hand_size + self.num_colors
-        )
-        self.rank_action_range = jnp.arange(
-            2 * self.hand_size + self.num_colors,
-            2 * self.hand_size + self.num_colors + self.num_ranks,
-        )
+        self.action_encodings = []
+        for p in range(self.num_agents):
+            action_encoding = {}
+            for i, a in enumerate(self.action_set):
+                if self._is_discard(a):
+                    move_type = f'D{i % hand_size}'
+                elif self._is_play(a):
+                    move_type = f'P{i % hand_size}'
+                elif self._is_hint_color(a):
+                    target_player, hint_idx = self._get_target_player_and_hint_index(p, a)
+                    color = self.color_map[hint_idx]
+                    move_type = f'H{color} to P{target_player}'
+                elif self._is_hint_rank(a):
+                    target_player, hint_idx = self._get_target_player_and_hint_index(p, a)
+                    move_type = f'H{hint_idx + 1} to P{target_player}'
+                else:
+                    move_type = 'N'
+                action_encoding[i] = move_type
+            self.action_encodings.append(action_encoding)

         # number of features
         self.hands_n_feats = (
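
Note: the _is_discard / _is_play / _is_hint_color / _is_hint_rank predicates used above are defined elsewhere in hanabi.py and do not appear in this diff. A minimal sketch of how such range checks could look, assuming the flat action space is laid out as discards, plays, color hints (one block per other player), rank hints (one block per other player), then a final noop action:

# Hypothetical sketch only; the real helpers live outside this diff and may differ.
# Layout assumption: [discards | plays | color hints | rank hints | noop].
def _is_discard(self, action):
    return action < self.hand_size

def _is_play(self, action):
    return (action >= self.hand_size) & (action < 2 * self.hand_size)

def _is_hint_color(self, action):
    start = 2 * self.hand_size
    n_hints = (self.num_agents - 1) * self.num_colors
    return (action >= start) & (action < start + n_hints)

def _is_hint_rank(self, action):
    start = 2 * self.hand_size + (self.num_agents - 1) * self.num_colors
    n_hints = (self.num_agents - 1) * self.num_ranks
    return (action >= start) & (action < start + n_hints)

def _is_hint(self, action):
    return self._is_hint_color(action) | self._is_hint_rank(action)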
@@ -136,14 +131,21 @@ def __init__(
     def reset(self, key: chex.PRNGKey) -> Tuple[Dict, State]:
         """Reset the environment and return the initial observation."""
         state = self.reset_game(key)
-        obs = self.get_obs(state, state, action=20)
+        obs = self.get_obs(state, state, action=self.num_moves - 1)
         return obs, state

     @partial(jax.jit, static_argnums=[0])
-    def reset_from_deck(self, key: chex.PRNGKey, deck: chex.Array) -> Tuple[Dict, State]:
+    def reset_from_deck(self, deck: chex.Array) -> Tuple[Dict, State]:
         """Inject a deck in the game. Useful for testing."""
-        state = self.reset_game_from_deck(key, deck)
-        obs = self.get_obs(state, state, action=20)
+        state = self.reset_game_from_deck(deck)
+        obs = self.get_obs(state, state, action=self.num_moves - 1)
+        return obs, state
+
+    @partial(jax.jit, static_argnums=[0])
+    def reset_from_deck_of_pairs(self, deck: chex.Array) -> Tuple[Dict, State]:
+        """Inject a deck from (color, rank) pairs."""
+        state = self.reset_game_from_deck_of_pairs(deck)
+        obs = self.get_obs(state, state, action=self.num_moves - 1)
         return obs, state

     @partial(jax.jit, static_argnums=[0])
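
For reference, a usage sketch of the new deck-injection entry point. The pair format (color index, rank index) and ordering are assumptions read off the method name and docstring, not confirmed by this diff:

# Hypothetical usage; a real call would pass the full deck (50 cards in
# standard Hanabi), ordered as the environment expects.
import jax.numpy as jnp
from jaxmarl import make

env = make("hanabi")
deck = jnp.array([[0, 0], [1, 0], [2, 4]])  # (color_idx, rank_idx) pairs
obs, state = env.reset_from_deck_of_pairs(deck)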
@@ -175,7 +177,7 @@ def step_env(

         obs = lax.stop_gradient(self.get_obs(new_state, old_state, action))

-        return (obs, lax.stop_gradient(new_state), rewards, dones, info)
+        return obs, lax.stop_gradient(new_state), rewards, dones, info

     @partial(jax.jit, static_argnums=[0])
     def get_obs(
@@ -184,7 +186,7 @@ def get_obs(
         """Get all agents' observations."""

         # no agent-specific obs
-        board_fats = self.get_board_fats(new_state)
+        board_fats = self.get_board_feats(new_state)
         discard_feats = self._binarize_discard_pile(new_state.discard_pile)

         def _observe(aidx: int):
@@ -242,7 +244,7 @@ def _legal_moves(aidx: int, state: State) -> chex.Array:
         hands = state.player_hands
         info_tokens = state.info_tokens
         # discard legal when discard tokens are not full
-        is_not_max_info_tokens = jnp.sum(state.info_tokens) < 8
+        is_not_max_info_tokens = jnp.sum(state.info_tokens) < self.max_info_tokens
         legal_moves = legal_moves.at[move_idx : move_idx + self.hand_size].set(
             is_not_max_info_tokens
         )
@@ -305,60 +307,58 @@ def get_last_action_feats_(
     ):
         """Get the features of the last action taken"""

-        acting_player_index = old_state.cur_player_idx  # absolute OH index
-        target_player_index = new_state.cur_player_idx  # absolute OH index
+        acting_player_index = old_state.cur_player_idx
         acting_player_relative_index = jnp.roll(
             acting_player_index, -aidx
         )  # relative OH index
+
+        target_player, hint_idx = self._get_target_player_and_hint_index(aidx, action)
         target_player_relative_index = jnp.roll(
-            target_player_index, -aidx
+            jax.nn.one_hot(target_player, num_classes=self.num_agents), -aidx
         )  # relative OH index

         # in obl the encoding order here is: play, discard, reveal_c, reveal_r
-        move_type = jnp.where(  # hard encoded but hey ho let's go
-            (action >= 0) & (action < 5),  # discard
-            jnp.array([0, 1, 0, 0]),
+        move_type = jnp.where(
+            self._is_play(action),
+            jnp.array([1, 0, 0, 0]),
             jnp.where(
-                (action >= 5) & (action < 10),  # play
-                jnp.array([1, 0, 0, 0]),
+                self._is_discard(action),
+                jnp.array([0, 1, 0, 0]),
                 jnp.where(
-                    (action >= 10) & (action < 15),  # reveal_c
+                    self._is_hint_color(action),
                     jnp.array([0, 0, 1, 0]),
                     jnp.where(
-                        (action >= 15) & (action < 20),  # reveal_r
+                        self._is_hint_rank(action),
                         jnp.array([0, 0, 0, 1]),
-                        jnp.array([0, 0, 0, 0]),  # invalid
+                        jnp.array([0, 0, 0, 0]),
                     ),
                 ),
-            ),
+            )
         )

         target_player_relative_index_feat = jnp.where(
-            action >= 10,  # only for hint actions
+            self._is_hint(action),  # only for hint actions
             target_player_relative_index,
             jnp.zeros(self.num_agents),
         )

         # get the hand of the target player
-        target_hand = new_state.player_hands[
-            jnp.nonzero(target_player_index, size=1)[0][0]
-        ]
+        target_hand = new_state.player_hands[target_player]

-        color_revealed = jnp.where(  # which color was revealed by action (oh)?
-            action == self.color_action_range,
-            1.0,
-            jnp.zeros(self.color_action_range.size),
+        color_revealed = jnp.where(
+            self._is_hint_color(action),
+            jax.nn.one_hot(hint_idx, num_classes=self.num_colors),
+            jnp.zeros(self.num_colors)
         )
-
-        rank_revealed = jnp.where(  # which rank was revealed by action (oh)?
-            action == self.rank_action_range,
-            1.0,
-            jnp.zeros(self.rank_action_range.size),
+        rank_revealed = jnp.where(
+            self._is_hint_rank(action),
+            jax.nn.one_hot(hint_idx, num_classes=self.num_ranks),
+            jnp.zeros(self.num_ranks)
         )

         # cards that have the color that was revealed
         color_revealed_cards = jnp.where(
-            (target_hand.sum(axis=(2)) == color_revealed).all(
+            (target_hand.sum(axis=2) == color_revealed).all(
                 axis=1
             ),  # color of the card==color reveled
             1,
@@ -367,7 +367,7 @@ def get_last_action_feats_(

         # cards that have the color that was revealed
         rank_revealed_cards = jnp.where(
-            (target_hand.sum(axis=(1)) == rank_revealed).all(
+            (target_hand.sum(axis=1) == rank_revealed).all(
                 axis=1
             ),  # color of the card==color reveled
             1,
@@ -379,8 +379,8 @@ def get_last_action_feats_(

         # card that was played-discarded
         pos_played_discarded = jnp.where(
-            action < 2 * self.hand_size,
-            jnp.arange(self.hand_size) == action % self.hand_size,
+            jnp.logical_or(self._is_play(action), self._is_discard(action)),
+            jnp.arange(self.hand_size) == (action % self.hand_size),
             jnp.zeros(self.hand_size),
         )
         actor_hand_before = old_state.player_hands[
@@ -402,7 +402,7 @@ def get_last_action_feats_(

         # "added info token" boolean is present only when you get an info from playing the 5 of the color
         added_info_tokens = jnp.where(
-            (action >= 5) & (action < 10),
+            self._is_play(action),
             new_state.info_tokens.sum() > old_state.info_tokens.sum(),
             0,
         )
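
The hunks above also rely on _get_target_player_and_hint_index, defined outside this diff. A sketch of one plausible decoding, consistent with the layout assumed earlier (hint blocks ordered by relative target offset):

# Hypothetical sketch; the real decoding lives elsewhere in hanabi.py.
def _get_target_player_and_hint_index(self, aidx, action):
    color_start = 2 * self.hand_size
    rank_start = color_start + (self.num_agents - 1) * self.num_colors
    # position inside the color-hint or rank-hint region
    offset = jnp.where(
        self._is_hint_color(action), action - color_start, action - rank_start
    )
    # each target player owns a contiguous block of num_colors (or num_ranks) actions
    block = jnp.where(self._is_hint_color(action), self.num_colors, self.num_ranks)
    target_offset = offset // block  # 0 = next player, 1 = the one after, ...
    hint_idx = offset % block        # which color or rank was hinted
    target_player = (aidx + 1 + target_offset) % self.num_agents
    return target_player, hint_idx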
@@ -443,9 +443,8 @@ def get_last_action_feats(

         return last_action

-
     @partial(jax.jit, static_argnums=[0])
-    def get_board_fats(self, state: State):
+    def get_board_feats(self, state: State):
         """Get the features of the board."""
         # by default the fireworks are incremental, i.e. [1,1,0,0,0] one and two are in the board
         # must be OH of only the highest rank, i.e. [0,1,0,0,0]
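
The comments in this last hunk describe turning the incremental fireworks encoding into a one-hot of the highest played rank. A sketch of that transform, assuming fireworks has shape (num_colors, num_ranks):

# Hypothetical sketch; the body of get_board_feats is truncated in this diff.
highest = state.fireworks.sum(axis=1)  # [1,1,0,0,0] -> 2 (rank two is on the board)
fireworks_oh = jnp.where(
    (highest > 0)[:, None],
    jax.nn.one_hot(highest - 1, num_classes=self.num_ranks),  # [1,1,0,0,0] -> [0,1,0,0,0]
    jnp.zeros_like(state.fireworks),  # an empty pile stays all-zero
)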
