Skip to content

Commit

Permalink
Fix issue #3. (#5)
Browse files Browse the repository at this point in the history
* Fix issue #3.

* Use the field from flax struct.

* Remove unused import.
  • Loading branch information
floringogianu authored Feb 1, 2024
1 parent abf181c commit e574386
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/xminigrid/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ class RuleSet(struct.PyTreeNode):


class AgentState(struct.PyTreeNode):
position: jax.Array = jnp.asarray((0, 0))
direction: jax.Array = jnp.asarray(0)
pocket: jax.Array = TILES_REGISTRY[Tiles.EMPTY, Colors.EMPTY]
position: jax.Array = struct.field(default_factory=lambda: jnp.asarray((0, 0)))
direction: jax.Array = struct.field(default_factory=lambda: jnp.asarray(0))
pocket: jax.Array = struct.field(default_factory=lambda: TILES_REGISTRY[Tiles.EMPTY, Colors.EMPTY])


class EnvCarry(struct.PyTreeNode):
Expand Down

0 comments on commit e574386

Please sign in to comment.