@@ -241,63 +241,57 @@ def get_legal_moves(self, state: State) -> chex.Array:
241
241
242
242
@partial (jax .vmap , in_axes = [0 , None ])
243
243
def _legal_moves (aidx : int , state : State ) -> chex .Array :
244
- """
245
- Legal moves encoding in order:
246
- - discard for all cards in hand
247
- - play for all cards in hand
248
- - hint for all colors and ranks for all other players
249
- """
250
- move_idx = 0
244
+ # all moves are illegal in the beginning
251
245
legal_moves = jnp .zeros (self .num_moves )
252
- hands = state .player_hands
253
- info_tokens = state .info_tokens
254
- # discard legal when discard tokens are not full
255
- is_not_max_info_tokens = jnp .sum (state .info_tokens ) < self .max_info_tokens
256
- legal_moves = legal_moves .at [move_idx : move_idx + self .hand_size ].set (
257
- is_not_max_info_tokens
258
- )
259
- move_idx += self .hand_size
260
- # play moves always legal
261
- legal_moves = legal_moves .at [move_idx : move_idx + self .hand_size ].set (1 )
262
- move_idx += self .hand_size
263
- # hints depend on other player cards
264
- other_hands = jnp .delete (hands , aidx , axis = 0 , assume_unique_indices = True )
265
-
266
- def _get_hints_for_hand (carry , unused ):
267
- """Generates valid hints given hand"""
268
- aidx , other_hands = carry
269
- hand = other_hands [aidx ]
270
-
271
- # get occurrences of each card
272
- card_counts = jnp .sum (hand , axis = 0 )
273
- # get occurrences of each color
274
- color_counts = jnp .sum (card_counts , axis = 1 )
275
- # get occurrences of each rank
276
- rank_counts = jnp .sum (card_counts , axis = 0 )
277
- # check which colors/ranks in hand
278
- colors_present = jnp .where (color_counts > 0 , 1 , 0 )
279
- ranks_present = jnp .where (rank_counts > 0 , 1 , 0 )
280
-
281
- valid_hints = jnp .concatenate ([colors_present , ranks_present ])
282
- carry = (aidx + 1 , other_hands )
283
-
284
- return carry , valid_hints
285
-
286
- _ , valid_hints = lax .scan (
287
- _get_hints_for_hand , (0 , other_hands ), None , self .num_agents - 1
246
+ all_player_hands = state .player_hands # (num_players, hand_size, num_colors, num_ranks)
247
+
248
+ # first get all cards one is holding since only these are legally playable
249
+ my_hand = all_player_hands .at [aidx ].get () # (hand_size, num_colors, num_ranks)
250
+ holding_cards_idx = jax .vmap (lambda c : jnp .any (c ))(my_hand )
251
+
252
+ # discard is legal when tokens are not full and when card is in hand
253
+ can_get_token = jnp .sum (state .info_tokens ) < self .max_info_tokens
254
+ legal_discard_idx = holding_cards_idx * can_get_token
255
+ legal_moves = legal_moves .at [self .discard_action_range ].set (legal_discard_idx )
256
+
257
+ # play is legal for cards in hand - if empty card, not legal.
258
+ legal_moves = legal_moves .at [self .play_action_range ].set (holding_cards_idx )
259
+
260
+ # hints depend on cards held by other players and not our own
261
+ other_players_hands = jnp .delete (
262
+ all_player_hands , aidx , axis = 0 , assume_unique_indices = True
288
263
)
289
- # make other player positions relative to current player
290
- valid_hints = jnp .roll (valid_hints , - aidx , axis = 0 )
291
- # include valid hints in legal moves
292
- num_hints = (self .num_agents - 1 ) * (self .num_colors + self .num_ranks )
293
- valid_hints = jnp .concatenate (valid_hints , axis = 0 )
294
- info_tokens_available = jnp .sum (info_tokens ) != 0
295
- valid_hints *= info_tokens_available
296
- legal_moves = legal_moves .at [move_idx : move_idx + num_hints ].set (
297
- valid_hints
264
+ # adjust to have relative positions
265
+ other_players_hands = jnp .roll (
266
+ other_players_hands , - aidx , axis = 0
298
267
)
299
268
300
- # only enable noop if not current player
269
+ # cards can be hinted only if info tokens are available
270
+ info_tokens_available = jnp .sum (state .info_tokens ) > 0
271
+
272
+ # get all the colors that can be hinted
273
+ def _hintable_colors (hand ):
274
+ # Hand: (num_cards, num_colors, num_ranks)
275
+ card_colors = jnp .sum (hand , axis = 2 )
276
+ hintable_colors = card_colors .any (axis = 0 )
277
+ return hintable_colors
278
+
279
+ legal_color_hints = jax .vmap (_hintable_colors )(other_players_hands ).ravel ()
280
+ legal_color_hints = legal_color_hints * info_tokens_available
281
+ legal_moves = legal_moves .at [self .color_action_range ].set (legal_color_hints )
282
+
283
+ # get all the ranks that can be hinted.
284
+ def _hintable_ranks (hand ):
285
+ # Hand: (num_cards, num_colors, num_ranks)
286
+ card_ranks = jnp .sum (hand , axis = 1 )
287
+ hintable_ranks = card_ranks .any (axis = 0 )
288
+ return hintable_ranks
289
+
290
+ legal_rank_hints = jax .vmap (_hintable_ranks )(other_players_hands ).ravel ()
291
+ legal_rank_hints = legal_rank_hints * info_tokens_available
292
+ legal_moves = legal_moves .at [self .rank_action_range ].set (legal_rank_hints )
293
+
294
+ # Only legalize noop if not current player.
301
295
cur_player = jnp .nonzero (state .cur_player_idx , size = 1 )[0 ][0 ]
302
296
not_cur_player = aidx != cur_player
303
297
legal_moves -= legal_moves * not_cur_player
0 commit comments