diff --git a/jaxmarl/environments/hanabi/hanabi.py b/jaxmarl/environments/hanabi/hanabi.py index f216f24a..aaf2e799 100644 --- a/jaxmarl/environments/hanabi/hanabi.py +++ b/jaxmarl/environments/hanabi/hanabi.py @@ -237,7 +237,7 @@ def _observe(aidx: int): return {a: obs[i] for i, a in enumerate(self.agents)} def get_legal_moves(self, state: State) -> chex.Array: - # Play is legal when card is in hand. + """Get all agents' legal moves""" @partial(jax.vmap, in_axes=[0, None]) def _legal_moves(aidx: int, state: State) -> chex.Array: