Skip to content

Commit

Permalink
solved discard pile bug
Browse files Browse the repository at this point in the history
  • Loading branch information
mttga committed Feb 8, 2024
1 parent f54a375 commit 3765afc
Showing 1 changed file with 5 additions and 10 deletions.
15 changes: 5 additions & 10 deletions jaxmarl/environments/hanabi/hanabi.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,11 +347,6 @@ def _discard_play_fn(state, action):
).astype(int)
card = hand_before.at[card_idx].get()

# discard selected card if discard action
discard_card = jnp.zeros_like(card) + (is_discard * card)
discard_pile = state.discard_pile.at[state.num_cards_discarded].set(discard_card)
num_cards_discarded = state.num_cards_discarded + is_discard

# gain an info token for discarding if discard action
infos_remaining = jnp.sum(state.info_tokens)
infos_depleted = (infos_remaining < self.max_info_tokens)
Expand All @@ -368,11 +363,11 @@ def _discard_play_fn(state, action):
color_fireworks = color_fireworks.at[0, jnp.sum(color_fireworks).astype(int)].set(make_play)
fireworks = state.fireworks.at[color].set(color_fireworks)

# discard if play action was invalid
failed_play = jnp.logical_and(jnp.logical_not(is_valid_play), jnp.logical_not(is_discard)).squeeze(0)
discard_card = jnp.zeros_like(card) + (failed_play * card)
discard_pile = state.discard_pile.at[state.num_cards_discarded].set(discard_card)
num_cards_discarded = state.num_cards_discarded + failed_play
# the card must be discarded if action is discard or the play action is not valid
discard_card = ((~is_valid_play)|(is_discard)).squeeze(0)
discarded_card = jnp.zeros_like(card) + (discard_card * card)
discard_pile = state.discard_pile.at[state.num_cards_discarded].set(discarded_card)
num_cards_discarded = state.num_cards_discarded + discard_card

# remove life token if invalid play
life_lost = jnp.logical_and(jnp.logical_not(is_valid_play),
Expand Down

0 comments on commit 3765afc

Please sign in to comment.