Skip to content

Commit

Permalink
general improvements (#19)
Browse files Browse the repository at this point in the history
  • Loading branch information
Howuhh authored May 1, 2024
1 parent b21c142 commit 756acc9
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 16 deletions.
2 changes: 1 addition & 1 deletion src/xminigrid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .registration import make, register, registered_environments

# TODO: add __all__
__version__ = "0.7.0"
__version__ = "0.7.1"

# ---------- XLand-MiniGrid environments ----------

Expand Down
1 change: 1 addition & 0 deletions src/xminigrid/core/goals.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .grid import equal, get_neighbouring_tiles, pad_along_axis

MAX_GOAL_ENCODING_LEN = 4 + 1 # for idx
NUM_GOALS = 15


def check_goal(
Expand Down
1 change: 1 addition & 0 deletions src/xminigrid/core/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .grid import equal, get_neighbouring_tiles, pad_along_axis

MAX_RULE_ENCODING_LEN = 6 + 1 # +1 for idx
NUM_RULES = 12


# this is very costly, will evaluate all rules under vmap. Submit a PR if you know how to do it better!
Expand Down
2 changes: 1 addition & 1 deletion src/xminigrid/envs/minigrid/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def step(
self, params: EnvParams, timestep: TimeStep[MemoryEnvCarry], action: IntOrArray
) -> TimeStep[MemoryEnvCarry]:
# disabling pick_up action
action = jax.lax.select(jnp.equal(action, 3), 5, action)
action = jax.lax.select(jnp.equal(action, 3), jnp.asarray(5, dtype=jnp.uint8), action)
new_grid, new_agent, _ = take_action(timestep.state.grid, timestep.state.agent, action)

new_state = timestep.state.replace(grid=new_grid, agent=new_agent, step_num=timestep.state.step_num + 1)
Expand Down
20 changes: 8 additions & 12 deletions src/xminigrid/manual_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,10 @@ def start(self) -> None:
def step(self, action: int) -> None:
self.timestep = self._step(self.env_params, self.timestep, action)
print(
"StepType: ",
self.timestep.step_type,
"Discount: ",
self.timestep.discount,
"Reward: ",
self.timestep.reward,
f"Step: {self.timestep.state.step_num} | ",
f"StepType: {self.timestep.step_type} | ",
f"Discount: {self.timestep.discount} | ",
f"Reward: {self.timestep.reward}",
)
self.render()

Expand All @@ -135,12 +133,10 @@ def reset(self) -> None:
self.timestep = self._reset(self.env_params, reset_key)
self.render()
print(
"StepType: ",
self.timestep.step_type,
"Discount: ",
self.timestep.discount,
"Reward: ",
self.timestep.reward,
f"Step: {self.timestep.state.step_num} |",
f"StepType: {self.timestep.step_type} |",
f"Discount: {self.timestep.discount} |",
f"Reward: {self.timestep.reward}",
)

def key_handler(self, event: Event) -> None:
Expand Down
4 changes: 2 additions & 2 deletions training/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@

def main():
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
checkpoint = orbax_checkpointer.restore("../xland-minigrid-data/XLand-MiniGrid-R1-9x9-Trivial-v0-5B-gamma99")
checkpoint = orbax_checkpointer.restore("../xland-minigrid-data/checkpoints")
config = checkpoint["config"]
params = checkpoint["params"]

env, env_params = xminigrid.make("XLand-MiniGrid-R1-9x9")
env = GymAutoResetWrapper(env)

ruleset = xminigrid.load_benchmark("Trivial").get_ruleset(3)
ruleset = xminigrid.load_benchmark("trivial-1m").get_ruleset(3)
env_params = env_params.replace(ruleset=ruleset)

model = ActorCriticRNN(
Expand Down

0 comments on commit 756acc9

Please sign in to comment.