From c1f2f22bc16c86dbb96b9069611c54dd95693e78 Mon Sep 17 00:00:00 2001 From: Florin Gogianu Date: Fri, 19 Jan 2024 19:02:12 +0200 Subject: [PATCH 1/3] Fix issue #3. --- src/xminigrid/types.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/xminigrid/types.py b/src/xminigrid/types.py index 326437b..c027798 100644 --- a/src/xminigrid/types.py +++ b/src/xminigrid/types.py @@ -1,3 +1,5 @@ +from dataclasses import field + import jax import jax.numpy as jnp from flax import struct @@ -17,9 +19,9 @@ class RuleSet(struct.PyTreeNode): class AgentState(struct.PyTreeNode): - position: jax.Array = jnp.asarray((0, 0)) - direction: jax.Array = jnp.asarray(0) - pocket: jax.Array = TILES_REGISTRY[Tiles.EMPTY, Colors.EMPTY] + position: jax.Array = field(default_factory=lambda: jnp.asarray((0, 0))) + direction: jax.Array = field(default_factory=lambda: jnp.asarray(0)) + pocket: jax.Array = field(default_factory=lambda: TILES_REGISTRY[Tiles.EMPTY, Colors.EMPTY]) class EnvCarry(struct.PyTreeNode): From 948a8b28ccfa04df5834906c4ba003d33782d574 Mon Sep 17 00:00:00 2001 From: Florin Gogianu Date: Wed, 31 Jan 2024 13:59:07 +0200 Subject: [PATCH 2/3] Use the field from flax struct. --- src/xminigrid/types.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/xminigrid/types.py b/src/xminigrid/types.py index c027798..a0b7c5c 100644 --- a/src/xminigrid/types.py +++ b/src/xminigrid/types.py @@ -19,9 +19,9 @@ class RuleSet(struct.PyTreeNode): class AgentState(struct.PyTreeNode): - position: jax.Array = field(default_factory=lambda: jnp.asarray((0, 0))) - direction: jax.Array = field(default_factory=lambda: jnp.asarray(0)) - pocket: jax.Array = field(default_factory=lambda: TILES_REGISTRY[Tiles.EMPTY, Colors.EMPTY]) + position: jax.Array = struct.field(default_factory=lambda: jnp.asarray((0, 0))) + direction: jax.Array = struct.field(default_factory=lambda: jnp.asarray(0)) + pocket: jax.Array = struct.field(default_factory=lambda: TILES_REGISTRY[Tiles.EMPTY, Colors.EMPTY]) class EnvCarry(struct.PyTreeNode): From 4b6888f3660253c29193878035324e945ed098d0 Mon Sep 17 00:00:00 2001 From: Florin Gogianu Date: Wed, 31 Jan 2024 14:05:12 +0200 Subject: [PATCH 3/3] Remove unused import. --- src/xminigrid/types.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/xminigrid/types.py b/src/xminigrid/types.py index a0b7c5c..7a69c0b 100644 --- a/src/xminigrid/types.py +++ b/src/xminigrid/types.py @@ -1,5 +1,3 @@ -from dataclasses import field - import jax import jax.numpy as jnp from flax import struct