diff --git a/src/xminigrid/types.py b/src/xminigrid/types.py index 5db9411..edab08a 100644 --- a/src/xminigrid/types.py +++ b/src/xminigrid/types.py @@ -18,9 +18,9 @@ class RuleSet(struct.PyTreeNode): class AgentState(struct.PyTreeNode): - position: jax.Array = jnp.asarray((0, 0)) - direction: jax.Array = jnp.asarray(0) - pocket: jax.Array = TILES_REGISTRY[Tiles.EMPTY, Colors.EMPTY] + position: jax.Array = struct.field(default_factory=lambda: jnp.asarray((0, 0))) + direction: jax.Array = struct.field(default_factory=lambda: jnp.asarray(0)) + pocket: jax.Array = struct.field(default_factory=lambda: TILES_REGISTRY[Tiles.EMPTY, Colors.EMPTY]) class EnvCarry(struct.PyTreeNode):