@@ -62,30 +62,25 @@ def __init__(
        ).squeeze()

        self.action_set = jnp.arange(self.num_moves)
-        self.action_encoding = {}
-        for i, a in enumerate(self.action_set):
-            if a < hand_size:
-                move_type = f'D{i % hand_size}'
-            elif a < hand_size * 2:
-                move_type = f'P{i % hand_size}'
-            elif a < hand_size * 2 + num_colors:
-                move_type = f'H{self.color_map[i - hand_size * 2]}'
-            elif a < hand_size * 2 + (num_colors + num_ranks):
-                move_type = f'H{i - (hand_size * 2 + num_colors) + 1}'
-            else:
-                move_type = 'N'
-            self.action_encoding[int(a)] = move_type
-
-        # useful ranges to know the type of the action
-        self.discard_action_range = jnp.arange(0, self.hand_size)
-        self.play_action_range = jnp.arange(self.hand_size, 2 * self.hand_size)
-        self.color_action_range = jnp.arange(
-            2 * self.hand_size, 2 * self.hand_size + self.num_colors
-        )
-        self.rank_action_range = jnp.arange(
-            2 * self.hand_size + self.num_colors,
-            2 * self.hand_size + self.num_colors + self.num_ranks,
-        )
+        self.action_encodings = []
+        for p in range(self.num_agents):
+            action_encoding = {}
+            for i, a in enumerate(self.action_set):
+                if self._is_discard(a):
+                    move_type = f'D{i % hand_size}'
+                elif self._is_play(a):
+                    move_type = f'P{i % hand_size}'
+                elif self._is_hint_color(a):
+                    target_player, hint_idx = self._get_target_player_and_hint_index(p, a)
+                    color = self.color_map[hint_idx]
+                    move_type = f'H{color} to P{target_player}'
+                elif self._is_hint_rank(a):
+                    target_player, hint_idx = self._get_target_player_and_hint_index(p, a)
+                    move_type = f'H{hint_idx + 1} to P{target_player}'
+                else:
+                    move_type = 'N'
+                action_encoding[i] = move_type
+            self.action_encodings.append(action_encoding)

        # number of features
        self.hands_n_feats = (
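For reference, here is a minimal sketch of the `_is_discard` / `_is_play` / `_is_hint_color` / `_is_hint_rank` predicates and `_get_target_player_and_hint_index` that this refactor calls. The layout below (discards, then plays, then per-target color hints, then per-target rank hints, then a final no-op) is an assumption inferred from the encoding loop above, not code taken from this PR:

import jax.numpy as jnp

# Assumed per-player action layout (example sizes; H=hand, P=players, C=colors, R=ranks):
# [discard x H | play x H | hint-color x (P-1)*C | hint-rank x (P-1)*R | no-op]
H, P, C, R = 5, 3, 5, 5

def is_discard(a):
    return (a >= 0) & (a < H)

def is_play(a):
    return (a >= H) & (a < 2 * H)

def is_hint_color(a):
    lo = 2 * H
    return (a >= lo) & (a < lo + (P - 1) * C)

def is_hint_rank(a):
    lo = 2 * H + (P - 1) * C
    return (a >= lo) & (a < lo + (P - 1) * R)

def target_player_and_hint_index(aidx, a):
    # Hypothetical grouping: hints are blocked first by relative target
    # (next player = 0), then by color/rank value within the block.
    lo = jnp.where(is_hint_color(a), 2 * H, 2 * H + (P - 1) * C)
    n = jnp.where(is_hint_color(a), C, R)
    rel_target = (a - lo) // n
    target = (aidx + 1 + rel_target) % P  # absolute index of the hinted player
    return target, (a - lo) % n           # (target player, hint index)

Under these example sizes, actions 10-19 would be color hints (10-14 to the next player, 15-19 to the one after), and num_moves would be 2*H + (P-1)*(C+R) + 1 = 31.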
@@ -136,14 +131,21 @@ def __init__(
    def reset(self, key: chex.PRNGKey) -> Tuple[Dict, State]:
        """Reset the environment and return the initial observation."""
        state = self.reset_game(key)
-        obs = self.get_obs(state, state, action=20)
+        obs = self.get_obs(state, state, action=self.num_moves - 1)
        return obs, state

    @partial(jax.jit, static_argnums=[0])
-    def reset_from_deck(self, key: chex.PRNGKey, deck: chex.Array) -> Tuple[Dict, State]:
+    def reset_from_deck(self, deck: chex.Array) -> Tuple[Dict, State]:
        """Inject a deck into the game. Useful for testing."""
-        state = self.reset_game_from_deck(key, deck)
-        obs = self.get_obs(state, state, action=20)
+        state = self.reset_game_from_deck(deck)
+        obs = self.get_obs(state, state, action=self.num_moves - 1)
+        return obs, state
+
+    @partial(jax.jit, static_argnums=[0])
+    def reset_from_deck_of_pairs(self, deck: chex.Array) -> Tuple[Dict, State]:
+        """Inject a deck from (color, rank) pairs."""
+        state = self.reset_game_from_deck_of_pairs(deck)
+        obs = self.get_obs(state, state, action=self.num_moves - 1)
        return obs, state

    @partial(jax.jit, static_argnums=[0])
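The initial observation is now built with `action = self.num_moves - 1`, i.e. the trailing 'N' no-op from the encoding loop, rather than a hard-coded 20 that was only correct for the 2-player game. A hedged usage sketch for the new deterministic reset; the (deck_size, 2) pair layout and the draw order are assumptions, not specified by this diff:

import jax.numpy as jnp

# env = ...  # an instance of this environment class
# Standard Hanabi deck: per color, three 1s, two each of 2/3/4, one 5.
counts = [3, 2, 2, 2, 1]
deck = jnp.array(
    [(c, r) for c in range(5) for r in range(5) for _ in range(counts[r])]
)  # assumed shape (50, 2) of (color, rank) pairs in draw order
# obs, state = env.reset_from_deck_of_pairs(deck)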
@@ -175,7 +177,7 @@ def step_env(

        obs = lax.stop_gradient(self.get_obs(new_state, old_state, action))

-        return (obs, lax.stop_gradient(new_state), rewards, dones, info)
+        return obs, lax.stop_gradient(new_state), rewards, dones, info

    @partial(jax.jit, static_argnums=[0])
    def get_obs(
@@ -184,7 +186,7 @@ def get_obs(
        """Get all agents' observations."""

        # no agent-specific obs
-        board_fats = self.get_board_fats(new_state)
+        board_fats = self.get_board_feats(new_state)
        discard_feats = self._binarize_discard_pile(new_state.discard_pile)

        def _observe(aidx: int):
@@ -242,7 +244,7 @@ def _legal_moves(aidx: int, state: State) -> chex.Array:
            hands = state.player_hands
            info_tokens = state.info_tokens
            # discard legal when info tokens are not full
-            is_not_max_info_tokens = jnp.sum(state.info_tokens) < 8
+            is_not_max_info_tokens = jnp.sum(state.info_tokens) < self.max_info_tokens
            legal_moves = legal_moves.at[move_idx : move_idx + self.hand_size].set(
                is_not_max_info_tokens
            )
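Replacing the hard-coded 8 with `self.max_info_tokens` keeps the rule configurable. A standalone sketch of the masking logic, with assumed shapes (a flat binary token vector and one legal-move slot per discard action):

import jax.numpy as jnp

max_info_tokens, hand_size = 8, 5
info_tokens = jnp.ones(max_info_tokens)  # every token available
legal_moves = jnp.zeros(2 * hand_size)
is_not_max_info_tokens = jnp.sum(info_tokens) < max_info_tokens
legal_moves = legal_moves.at[0:hand_size].set(is_not_max_info_tokens)
# discarding regains a token, so all discard slots stay 0 while tokens are full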
@@ -305,60 +307,58 @@ def get_last_action_feats_(
    ):
        """Get the features of the last action taken."""

-        acting_player_index = old_state.cur_player_idx  # absolute OH index
-        target_player_index = new_state.cur_player_idx  # absolute OH index
+        acting_player_index = old_state.cur_player_idx
        acting_player_relative_index = jnp.roll(
            acting_player_index, -aidx
        )  # relative OH index
+
+        target_player, hint_idx = self._get_target_player_and_hint_index(aidx, action)
        target_player_relative_index = jnp.roll(
-            target_player_index, -aidx
+            jax.nn.one_hot(target_player, num_classes=self.num_agents), -aidx
        )  # relative OH index

        # in obl the encoding order here is: play, discard, reveal_c, reveal_r
-        move_type = jnp.where(  # hard-coded, but hey ho let's go
-            (action >= 0) & (action < 5),  # discard
-            jnp.array([0, 1, 0, 0]),
+        move_type = jnp.where(
+            self._is_play(action),
+            jnp.array([1, 0, 0, 0]),
            jnp.where(
-                (action >= 5) & (action < 10),  # play
-                jnp.array([1, 0, 0, 0]),
+                self._is_discard(action),
+                jnp.array([0, 1, 0, 0]),
                jnp.where(
-                    (action >= 10) & (action < 15),  # reveal_c
+                    self._is_hint_color(action),
                    jnp.array([0, 0, 1, 0]),
                    jnp.where(
-                        (action >= 15) & (action < 20),  # reveal_r
+                        self._is_hint_rank(action),
                        jnp.array([0, 0, 0, 1]),
-                        jnp.array([0, 0, 0, 0]),  # invalid
+                        jnp.array([0, 0, 0, 0]),
                    ),
                ),
-            ),
+            )
        )

        target_player_relative_index_feat = jnp.where(
-            action >= 10,  # only for hint actions
+            self._is_hint(action),  # only for hint actions
            target_player_relative_index,
            jnp.zeros(self.num_agents),
        )

        # get the hand of the target player
-        target_hand = new_state.player_hands[
-            jnp.nonzero(target_player_index, size=1)[0][0]
-        ]
+        target_hand = new_state.player_hands[target_player]

-        color_revealed = jnp.where(  # which color was revealed by the action (OH)?
-            action == self.color_action_range,
-            1.0,
-            jnp.zeros(self.color_action_range.size),
+        color_revealed = jnp.where(
+            self._is_hint_color(action),
+            jax.nn.one_hot(hint_idx, num_classes=self.num_colors),
+            jnp.zeros(self.num_colors)
        )
-
-        rank_revealed = jnp.where(  # which rank was revealed by the action (OH)?
-            action == self.rank_action_range,
-            1.0,
-            jnp.zeros(self.rank_action_range.size),
+        rank_revealed = jnp.where(
+            self._is_hint_rank(action),
+            jax.nn.one_hot(hint_idx, num_classes=self.num_ranks),
+            jnp.zeros(self.num_ranks)
        )

        # cards that have the color that was revealed
        color_revealed_cards = jnp.where(
-            (target_hand.sum(axis=(2)) == color_revealed).all(
+            (target_hand.sum(axis=2) == color_revealed).all(
                axis=1
            ),  # color of the card == color revealed
            1,
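A toy illustration of the `jnp.where` + `jax.nn.one_hot` pattern now used for `color_revealed` and `rank_revealed` (values made up):

import jax
import jax.numpy as jnp

num_colors = 5
hint_idx = jnp.asarray(2)
is_color_hint = jnp.asarray(True)  # stands in for self._is_hint_color(action)
color_revealed = jnp.where(
    is_color_hint,
    jax.nn.one_hot(hint_idx, num_classes=num_colors),
    jnp.zeros(num_colors),
)
# -> [0., 0., 1., 0., 0.]; a False predicate collapses it to all zeros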
@@ -367,7 +367,7 @@ def get_last_action_feats_(

        # cards that have the rank that was revealed
        rank_revealed_cards = jnp.where(
-            (target_hand.sum(axis=(1)) == rank_revealed).all(
+            (target_hand.sum(axis=1) == rank_revealed).all(
                axis=1
            ),  # rank of the card == rank revealed
            1,
@@ -379,8 +379,8 @@ def get_last_action_feats_(

        # card that was played/discarded
        pos_played_discarded = jnp.where(
-            action < 2 * self.hand_size,
-            jnp.arange(self.hand_size) == action % self.hand_size,
+            jnp.logical_or(self._is_play(action), self._is_discard(action)),
+            jnp.arange(self.hand_size) == (action % self.hand_size),
            jnp.zeros(self.hand_size),
        )
        actor_hand_before = old_state.player_hands[
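Quick sanity check of the position one-hot above, with made-up values (action 7 would be play slot 2 under the assumed [discard x5 | play x5 | ...] layout):

import jax.numpy as jnp

hand_size = 5
action = jnp.asarray(7)
is_play_or_discard = jnp.asarray(True)  # stands in for the predicate pair
pos_played_discarded = jnp.where(
    is_play_or_discard,
    jnp.arange(hand_size) == (action % hand_size),
    jnp.zeros(hand_size),
)
# -> [0., 0., 1., 0., 0.]  (card slot 2)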
@@ -402,7 +402,7 @@ def get_last_action_feats_(

        # "added info token" boolean is present only when you get an info token from playing the 5 of a color
        added_info_tokens = jnp.where(
-            (action >= 5) & (action < 10),
+            self._is_play(action),
            new_state.info_tokens.sum() > old_state.info_tokens.sum(),
            0,
        )
@@ -443,9 +443,8 @@ def get_last_action_feats(

        return last_action

-
    @partial(jax.jit, static_argnums=[0])
-    def get_board_fats(self, state: State):
+    def get_board_feats(self, state: State):
        """Get the features of the board."""
        # by default the fireworks are incremental, i.e. [1,1,0,0,0] means ranks one and two are on the board;
        # it must be OH of only the highest rank, i.e. [0,1,0,0,0]
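One way the incremental-to-one-hot conversion described in that comment might look (a sketch, not necessarily what this method does):

import jax.numpy as jnp

fireworks = jnp.array([1., 1., 0., 0., 0.])        # ranks 1 and 2 are on the board
shifted = jnp.concatenate([fireworks[1:], jnp.zeros(1)])
highest_only = fireworks - shifted                 # -> [0., 1., 0., 0., 0.]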