Skip to content

Commit

Permalink
Time limits rework (#20)
Browse files Browse the repository at this point in the history
* rework

* simple none check

* update walkthrough

* revert xland time limit, new env regs
  • Loading branch information
Howuhh committed May 5, 2024
1 parent 756acc9 commit 991f13c
Show file tree
Hide file tree
Showing 14 changed files with 126 additions and 73 deletions.
3 changes: 0 additions & 3 deletions examples/walkthrough.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,6 @@
" def observation_shape(self, params: EnvParams) -> tuple[int, int, int]:\n",
" ...\n",
"\n",
" def time_limit(self, params: EnvParams) -> int:\n",
" ...\n",
"\n",
" def _generate_problem(self, params: EnvParams, key: jax.Array) -> State:\n",
" ...\n",
"\n",
Expand Down
50 changes: 49 additions & 1 deletion src/xminigrid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .registration import make, register, registered_environments

# TODO: add __all__
__version__ = "0.7.1"
__version__ = "0.8.0"

# ---------- XLand-MiniGrid environments ----------

Expand Down Expand Up @@ -90,6 +90,14 @@
width=9,
)

register(
id="XLand-MiniGrid-R1-11x11",
entry_point="xminigrid.envs.xland:XLandMiniGrid",
grid_type="R1",
height=11,
width=11,
)

register(
id="XLand-MiniGrid-R1-13x13",
entry_point="xminigrid.envs.xland:XLandMiniGrid",
Expand All @@ -98,6 +106,14 @@
width=13,
)

register(
id="XLand-MiniGrid-R1-15x15",
entry_point="xminigrid.envs.xland:XLandMiniGrid",
grid_type="R1",
height=15,
width=15,
)

register(
id="XLand-MiniGrid-R1-17x17",
entry_point="xminigrid.envs.xland:XLandMiniGrid",
Expand All @@ -115,6 +131,14 @@
width=9,
)

register(
id="XLand-MiniGrid-R2-11x11",
entry_point="xminigrid.envs.xland:XLandMiniGrid",
grid_type="R2",
height=11,
width=11,
)

register(
id="XLand-MiniGrid-R2-13x13",
entry_point="xminigrid.envs.xland:XLandMiniGrid",
Expand All @@ -123,6 +147,14 @@
width=13,
)

register(
id="XLand-MiniGrid-R2-15x15",
entry_point="xminigrid.envs.xland:XLandMiniGrid",
grid_type="R2",
height=15,
width=15,
)

register(
id="XLand-MiniGrid-R2-17x17",
entry_point="xminigrid.envs.xland:XLandMiniGrid",
Expand All @@ -140,6 +172,14 @@
width=9,
)

register(
id="XLand-MiniGrid-R4-11x11",
entry_point="xminigrid.envs.xland:XLandMiniGrid",
grid_type="R4",
height=11,
width=11,
)

register(
id="XLand-MiniGrid-R4-13x13",
entry_point="xminigrid.envs.xland:XLandMiniGrid",
Expand All @@ -148,6 +188,14 @@
width=13,
)

register(
id="XLand-MiniGrid-R4-15x15",
entry_point="xminigrid.envs.xland:XLandMiniGrid",
grid_type="R4",
height=15,
width=15,
)

register(
id="XLand-MiniGrid-R4-17x17",
entry_point="xminigrid.envs.xland:XLandMiniGrid",
Expand Down
13 changes: 6 additions & 7 deletions src/xminigrid/environment.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import abc
from typing import Any, Generic, TypeVar
from typing import Any, Generic, Optional, TypeVar

import jax
import jax.numpy as jnp
Expand All @@ -26,6 +26,7 @@ class EnvParams(struct.PyTreeNode):
# Static (non-pytree) environment configuration fields.
height: int = struct.field(pytree_node=False, default=9)
width: int = struct.field(pytree_node=False, default=9)
view_size: int = struct.field(pytree_node=False, default=7)
# Episode step limit. None means "not set yet": each environment's default_params
# replaces None with its own default, and step() asserts that the value is not None.
max_steps: Optional[int] = struct.field(pytree_node=False, default=None)
render_mode: str = struct.field(pytree_node=False, default="rgb_array")


Expand All @@ -43,10 +44,6 @@ def num_actions(self, params: EnvParamsT) -> int:
def observation_shape(self, params: EnvParamsT) -> tuple[int, int, int]:
return params.view_size, params.view_size, NUM_LAYERS

# TODO: NOT sure that this should be hardcoded like that...
def time_limit(self, params: EnvParamsT) -> int:
return 3 * params.height * params.width

@abc.abstractmethod
def _generate_problem(self, params: EnvParamsT, key: jax.Array) -> State[EnvCarryT]:
...
Expand Down Expand Up @@ -76,9 +73,11 @@ def step(self, params: EnvParamsT, timestep: TimeStep[EnvCarryT], action: IntOrA

# checking for termination or truncation, choosing step type
terminated = check_goal(new_state.goal_encoding, new_state.grid, new_state.agent, action, changed_position)
truncated = jnp.equal(new_state.step_num, self.time_limit(params))

reward = jax.lax.select(terminated, 1.0 - 0.9 * (new_state.step_num / self.time_limit(params)), 0.0)
assert params.max_steps is not None
truncated = jnp.equal(new_state.step_num, params.max_steps)

reward = jax.lax.select(terminated, 1.0 - 0.9 * (new_state.step_num / params.max_steps), 0.0)

step_type = jax.lax.select(terminated | truncated, StepType.LAST, StepType.MID)
discount = jax.lax.select(terminated, jnp.asarray(0.0), jnp.asarray(1.0))
Expand Down
11 changes: 6 additions & 5 deletions src/xminigrid/envs/minigrid/blockedunlockpickup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,13 @@

class BlockedUnlockPickUp(Environment[EnvParams, EnvCarry]):
def default_params(self, **kwargs) -> EnvParams:
default_params = EnvParams(height=6, width=11)
default_params = default_params.replace(**kwargs)
return default_params
params = EnvParams(height=6, width=11)
params = params.replace(**kwargs)

def time_limit(self, params: EnvParams) -> int:
return 16 * params.height**2
if params.max_steps is None:
# formula directly taken from MiniGrid
params = params.replace(max_steps=16 * params.height**2)
return params

def _generate_problem(self, params: EnvParams, key: jax.Array) -> State[EnvCarry]:
key, _key = jax.random.split(key)
Expand Down
11 changes: 6 additions & 5 deletions src/xminigrid/envs/minigrid/doorkey.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@

class DoorKey(Environment[EnvParams, EnvCarry]):
def default_params(self, **kwargs) -> EnvParams:
default_params = EnvParams(height=5, width=5)
default_params = default_params.replace(**kwargs)
return default_params
params = EnvParams(height=5, width=5)
params = params.replace(**kwargs)

def time_limit(self, params: EnvParams) -> int:
return 10 * (params.height * params.width)
if params.max_steps is None:
# formula directly taken from MiniGrid
params = params.replace(max_steps=10 * (params.height * params.width))
return params

def _generate_problem(self, params: EnvParams, key: jax.Array) -> State[EnvCarry]:
key, _key = jax.random.split(key)
Expand Down
22 changes: 12 additions & 10 deletions src/xminigrid/envs/minigrid/empty.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@

class Empty(Environment[EnvParams, EnvCarry]):
def default_params(self, **kwargs) -> EnvParams:
default_params = EnvParams(height=9, width=9)
default_params = default_params.replace(**kwargs)
return default_params
params = EnvParams(height=9, width=9)
params = params.replace(**kwargs)

def time_limit(self, params: EnvParams) -> int:
return 4 * (params.height * params.width)
if params.max_steps is None:
# formula directly taken from MiniGrid
params = params.replace(max_steps=4 * (params.height * params.width))
return params

def _generate_problem(self, params: EnvParams, key: jax.Array) -> State[EnvCarry]:
grid = room(params.height, params.width)
Expand All @@ -44,12 +45,13 @@ def _generate_problem(self, params: EnvParams, key: jax.Array) -> State[EnvCarry

class EmptyRandom(Environment[EnvParams, EnvCarry]):
def default_params(self, **kwargs) -> EnvParams:
default_params = EnvParams(height=9, width=9)
default_params = default_params.replace(**kwargs)
return default_params
params = EnvParams(height=9, width=9)
params = params.replace(**kwargs)

def time_limit(self, params: EnvParams) -> int:
return 4 * (params.height * params.width)
if params.max_steps is None:
# formula directly taken from MiniGrid
params = params.replace(max_steps=4 * (params.height * params.width))
return params

def _generate_problem(self, params: EnvParams, key: jax.Array) -> State[EnvCarry]:
key, pos_key, dir_key = jax.random.split(key, num=3)
Expand Down
12 changes: 6 additions & 6 deletions src/xminigrid/envs/minigrid/fourrooms.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@

class FourRooms(Environment[EnvParams, EnvCarry]):
def default_params(self, **kwargs) -> EnvParams:
default_params = EnvParams(height=19, width=19)
default_params = default_params.replace(**kwargs)
return default_params
params = EnvParams(height=19, width=19)
params = params.replace(**kwargs)

def time_limit(self, params: EnvParams) -> int:
# TODO: this is hardcoded and thus problematic. Move it to EnvParams?
return 100
if params.max_steps is None:
# constant directly taken from MiniGrid
params = params.replace(max_steps=100)
return params

def _generate_problem(self, params: EnvParams, key: jax.Array) -> State[EnvCarry]:
key, *keys = jax.random.split(key, num=4)
Expand Down
11 changes: 6 additions & 5 deletions src/xminigrid/envs/minigrid/lockedroom.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@

class LockedRoom(Environment[EnvParams, EnvCarry]):
def default_params(self, **kwargs) -> EnvParams:
default_params = EnvParams(height=19, width=19)
default_params = default_params.replace(**kwargs)
return default_params
params = EnvParams(height=19, width=19)
params = params.replace(**kwargs)

def time_limit(self, params: EnvParams) -> int:
return 10 * params.height
if params.max_steps is None:
# formula directly taken from MiniGrid
params = params.replace(max_steps=10 * params.height)
return params

def _generate_problem(self, params: EnvParams, key: jax.Array) -> State[EnvCarry]:
key, rooms_key, colors_key, objects_key, coords_key, agent_pos_key, agent_dir_key = jax.random.split(key, num=7)
Expand Down
15 changes: 8 additions & 7 deletions src/xminigrid/envs/minigrid/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,13 @@ class MemoryEnvCarry(struct.PyTreeNode):
# TODO: Random corridor length is a bit problematic due to the dynamic slicing.
class Memory(Environment[EnvParams, MemoryEnvCarry]):
def default_params(self, **kwargs) -> EnvParams:
default_params = EnvParams(height=7, width=13, view_size=3)
default_params = default_params.replace(**kwargs)
return default_params
params = EnvParams(height=7, width=13, view_size=3)
params = params.replace(**kwargs)

def time_limit(self, params: EnvParams) -> int:
return 5 * params.width**2
if params.max_steps is None:
# formula directly taken from MiniGrid
params = params.replace(max_steps=5 * params.width**2)
return params

def _generate_problem(self, params: EnvParams, key: jax.Array) -> State[MemoryEnvCarry]:
key, corridor_key, agent_key, mem_key, place_key = jax.random.split(key, num=5)
Expand Down Expand Up @@ -115,14 +116,14 @@ def step(
new_state = timestep.state.replace(grid=new_grid, agent=new_agent, step_num=timestep.state.step_num + 1)
new_observation = transparent_field_of_view(new_state.grid, new_state.agent, params.view_size, params.view_size)

truncated = new_state.step_num == self.time_limit(params)
truncated = new_state.step_num == params.max_steps
terminated = jnp.logical_or(
jnp.array_equal(new_agent.position, new_state.carry.success_pos),
jnp.array_equal(new_agent.position, new_state.carry.failure_pos),
)
reward = jax.lax.select(
jnp.array_equal(new_agent.position, new_state.carry.success_pos),
1.0 - 0.9 * (new_state.step_num / self.time_limit(params)),
1.0 - 0.9 * (new_state.step_num / params.max_steps),
0.0,
)
step_type = jax.lax.select(terminated | truncated, StepType.LAST, StepType.MID)
Expand Down
9 changes: 6 additions & 3 deletions src/xminigrid/envs/minigrid/playground.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,13 @@ class PlaygroundEnvParams(EnvParams):

class Playground(Environment[PlaygroundEnvParams, EnvCarry]):
def default_params(self, **kwargs) -> PlaygroundEnvParams:
return PlaygroundEnvParams(height=19, width=19).replace(**kwargs)
params = PlaygroundEnvParams(height=19, width=19)
params = params.replace(**kwargs)

def time_limit(self, params: EnvParams) -> int:
return 512
if params.max_steps is None:
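# fixed step limit (a constant, not a size-dependent formula)
# NOTE(review): originally described as "directly taken from MiniGrid" - confirm against upstream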
params = params.replace(max_steps=512)
return params

def _generate_problem(self, params: PlaygroundEnvParams, key: jax.Array) -> State[EnvCarry]:
key, *keys = jax.random.split(key, num=6)
Expand Down
11 changes: 6 additions & 5 deletions src/xminigrid/envs/minigrid/unlock.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@

class Unlock(Environment[EnvParams, EnvCarry]):
def default_params(self, **kwargs) -> EnvParams:
default_params = EnvParams(height=6, width=11)
default_params = default_params.replace(**kwargs)
return default_params
params = EnvParams(height=6, width=11)
params = params.replace(**kwargs)

def time_limit(self, params: EnvParams) -> int:
return 8 * params.height**2
if params.max_steps is None:
# formula directly taken from MiniGrid
params = params.replace(max_steps=8 * params.height**2)
return params

def _generate_problem(self, params: EnvParams, key: jax.Array) -> State[EnvCarry]:
key, *keys = jax.random.split(key, num=5)
Expand Down
11 changes: 6 additions & 5 deletions src/xminigrid/envs/minigrid/unlockpickup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,13 @@

class UnlockPickUp(Environment[EnvParams, EnvCarry]):
def default_params(self, **kwargs) -> EnvParams:
default_params = EnvParams(height=6, width=11)
default_params = default_params.replace(**kwargs)
return default_params
params = EnvParams(height=6, width=11)
params = params.replace(**kwargs)

def time_limit(self, params: EnvParams) -> int:
return 8 * params.height**2
if params.max_steps is None:
# formula directly taken from MiniGrid
params = params.replace(max_steps=8 * params.height**2)
return params

def _generate_problem(self, params: EnvParams, key: jax.Array) -> State[EnvCarry]:
key, *keys = jax.random.split(key, num=7)
Expand Down
17 changes: 9 additions & 8 deletions src/xminigrid/envs/xland.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,14 +144,15 @@ class XLandEnvParams(EnvParams):

class XLandMiniGrid(Environment[XLandEnvParams, EnvCarry]):
def default_params(self, **kwargs) -> XLandEnvParams:
default_params = XLandEnvParams(view_size=5)
return default_params.replace(**kwargs)

def time_limit(self, params: XLandEnvParams) -> int:
# this is just a heuristic to prevent brute force in one episode,
# agent need to remember what he tried in previous episodes.
# If this is too small, change it or increase number of trials (these are not equivalent).
return 3 * (params.height * params.width)
params = XLandEnvParams(view_size=5)
params = params.replace(**kwargs)

if params.max_steps is None:
# this is just a heuristic to prevent brute force in one episode,
# so that the agent needs to remember what it tried in previous episodes.
# If this is too small, change it or increase the number of trials (NB: these are not equivalent).
params = params.replace(max_steps=3 * (params.height * params.width))
return params

def _generate_problem(self, params: XLandEnvParams, key: jax.Array) -> State[EnvCarry]:
# WARN: we can make this compatible with jit (to vmap on different layouts during training),
Expand Down
3 changes: 0 additions & 3 deletions src/xminigrid/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@ def num_actions(self, params: EnvParamsT) -> int:
def observation_shape(self, params: EnvParamsT) -> tuple[int, int, int]:
return self._env.observation_shape(params)

def time_limit(self, params: EnvParamsT) -> int:
return self._env.time_limit(params)

def _generate_problem(self, params: EnvParamsT, key: jax.Array) -> State[EnvCarryT]:
return self._env._generate_problem(params, key)

Expand Down

0 comments on commit 991f13c

Please sign in to comment.