From 4a24b157e7b65cf19d6db2cbb148bee11c558d66 Mon Sep 17 00:00:00 2001 From: Howuhh Date: Fri, 1 Mar 2024 22:36:16 +0300 Subject: [PATCH 01/14] image obs rendering --- src/xminigrid/__init__.py | 2 +- src/xminigrid/core/constants.py | 66 ++++++++----------- src/xminigrid/core/grid.py | 4 +- src/xminigrid/core/observation.py | 6 +- src/xminigrid/envs/xland.py | 1 - src/xminigrid/experimental/img_obs.py | 90 ++++++++++++++++++++++++++ src/xminigrid/manual_control.py | 16 ++++- src/xminigrid/rendering/rgb_render.py | 17 ++++- src/xminigrid/rendering/text_render.py | 4 -- src/xminigrid/wrappers.py | 6 ++ 10 files changed, 157 insertions(+), 55 deletions(-) create mode 100644 src/xminigrid/experimental/img_obs.py diff --git a/src/xminigrid/__init__.py b/src/xminigrid/__init__.py index 2e0abaa..a67179d 100644 --- a/src/xminigrid/__init__.py +++ b/src/xminigrid/__init__.py @@ -2,7 +2,7 @@ from .registration import make, register, registered_environments # TODO: add __all__ -__version__ = "0.6.0" +__version__ = "0.7.0" # ---------- XLand-MiniGrid environments ---------- diff --git a/src/xminigrid/core/constants.py b/src/xminigrid/core/constants.py index 1d36d73..f9b9241 100644 --- a/src/xminigrid/core/constants.py +++ b/src/xminigrid/core/constants.py @@ -1,48 +1,44 @@ import jax.numpy as jnp from flax import struct +NUM_ACTIONS = 6 + # GRID: [tile, color] NUM_LAYERS = 2 -NUM_TILES = 15 -NUM_COLORS = 14 -NUM_ACTIONS = 6 +NUM_TILES = 13 +NUM_COLORS = 12 -# TODO: do we really need END_OF_MAP? seem like unseen can be used instead... # enums, kinda... class Tiles(struct.PyTreeNode): EMPTY: int = struct.field(pytree_node=False, default=0) - END_OF_MAP: int = struct.field(pytree_node=False, default=1) - UNSEEN: int = struct.field(pytree_node=False, default=2) - FLOOR: int = struct.field(pytree_node=False, default=3) - WALL: int = struct.field(pytree_node=False, default=4) - BALL: int = struct.field(pytree_node=False, default=5) - SQUARE: int = struct.field(pytree_node=False, default=6) - PYRAMID: int = struct.field(pytree_node=False, default=7) - GOAL: int = struct.field(pytree_node=False, default=8) - KEY: int = struct.field(pytree_node=False, default=9) - DOOR_LOCKED: int = struct.field(pytree_node=False, default=10) - DOOR_CLOSED: int = struct.field(pytree_node=False, default=11) - DOOR_OPEN: int = struct.field(pytree_node=False, default=12) - HEX: int = struct.field(pytree_node=False, default=13) - STAR: int = struct.field(pytree_node=False, default=14) + FLOOR: int = struct.field(pytree_node=False, default=1) + WALL: int = struct.field(pytree_node=False, default=2) + BALL: int = struct.field(pytree_node=False, default=3) + SQUARE: int = struct.field(pytree_node=False, default=4) + PYRAMID: int = struct.field(pytree_node=False, default=5) + GOAL: int = struct.field(pytree_node=False, default=6) + KEY: int = struct.field(pytree_node=False, default=7) + DOOR_LOCKED: int = struct.field(pytree_node=False, default=8) + DOOR_CLOSED: int = struct.field(pytree_node=False, default=9) + DOOR_OPEN: int = struct.field(pytree_node=False, default=10) + HEX: int = struct.field(pytree_node=False, default=11) + STAR: int = struct.field(pytree_node=False, default=12) class Colors(struct.PyTreeNode): EMPTY: int = struct.field(pytree_node=False, default=0) - END_OF_MAP: int = struct.field(pytree_node=False, default=1) - UNSEEN: int = struct.field(pytree_node=False, default=2) - RED: int = struct.field(pytree_node=False, default=3) - GREEN: int = struct.field(pytree_node=False, default=4) - BLUE: int = struct.field(pytree_node=False, default=5) - PURPLE: int = struct.field(pytree_node=False, default=6) - YELLOW: int = struct.field(pytree_node=False, default=7) - GREY: int = struct.field(pytree_node=False, default=8) - BLACK: int = struct.field(pytree_node=False, default=9) - ORANGE: int = struct.field(pytree_node=False, default=10) - WHITE: int = struct.field(pytree_node=False, default=11) - BROWN: int = struct.field(pytree_node=False, default=12) - PINK: int = struct.field(pytree_node=False, default=13) + RED: int = struct.field(pytree_node=False, default=1) + GREEN: int = struct.field(pytree_node=False, default=2) + BLUE: int = struct.field(pytree_node=False, default=3) + PURPLE: int = struct.field(pytree_node=False, default=4) + YELLOW: int = struct.field(pytree_node=False, default=5) + GREY: int = struct.field(pytree_node=False, default=6) + BLACK: int = struct.field(pytree_node=False, default=7) + ORANGE: int = struct.field(pytree_node=False, default=8) + WHITE: int = struct.field(pytree_node=False, default=9) + BROWN: int = struct.field(pytree_node=False, default=10) + PINK: int = struct.field(pytree_node=False, default=11) # Only ~100 combinations so far, better to preallocate them @@ -65,7 +61,6 @@ class Colors(struct.PyTreeNode): WALKABLE = jnp.array( ( - Tiles.EMPTY, Tiles.FLOOR, Tiles.GOAL, Tiles.DOOR_OPEN, @@ -83,12 +78,7 @@ class Colors(struct.PyTreeNode): ) ) -FREE_TO_PUT_DOWN = jnp.array( - ( - Tiles.EMPTY, - Tiles.FLOOR, - ) -) +FREE_TO_PUT_DOWN = jnp.array((Tiles.FLOOR,)) LOS_BLOCKING = jnp.array( ( diff --git a/src/xminigrid/core/grid.py b/src/xminigrid/core/grid.py index 803b6c7..b3152bd 100644 --- a/src/xminigrid/core/grid.py +++ b/src/xminigrid/core/grid.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Callable, Union +from typing import Callable import jax import jax.numpy as jnp @@ -22,7 +22,7 @@ def equal(tile1: Tile, tile2: Tile) -> Tile: def get_neighbouring_tiles(grid: GridState, y: IntOrArray, x: IntOrArray) -> tuple[Tile, Tile, Tile, Tile]: # end_of_map = TILES_REGISTRY[Tiles.END_OF_MAP, Colors.END_OF_MAP] - end_of_map = Tiles.END_OF_MAP + end_of_map = Tiles.EMPTY up_tile = grid.at[y - 1, x].get(mode="fill", fill_value=end_of_map) right_tile = grid.at[y, x + 1].get(mode="fill", fill_value=end_of_map) diff --git a/src/xminigrid/core/observation.py b/src/xminigrid/core/observation.py index e2c4ab8..e8ea031 100644 --- a/src/xminigrid/core/observation.py +++ b/src/xminigrid/core/observation.py @@ -12,7 +12,7 @@ def crop_field_of_view(grid: GridState, agent: AgentState, height: int, width: i grid = jnp.pad( grid, pad_width=((height, height), (width, width), (0, 0)), - constant_values=Tiles.END_OF_MAP, + constant_values=Tiles.EMPTY, ) # account for padding y = agent.position[0] + height @@ -110,8 +110,8 @@ def minigrid_field_of_view(grid: GridState, agent: AgentState, height: int, widt fov_grid = crop_field_of_view(grid, agent, height, width) fov_grid = align_with_up(fov_grid, agent.direction) mask = generate_viz_mask_minigrid(fov_grid) - # set UNSEEN value for all layers (including colors, as UNSEEN color has same id value) - fov_grid = jnp.where(mask[..., None], fov_grid, Tiles.UNSEEN) + # set EMPTY as unseen value for all layers (including colors, as EMPTY color has same id value) + fov_grid = jnp.where(mask[..., None], fov_grid, Tiles.EMPTY) # TODO: should we even do this? Agent with good memory can remember what he picked up. # WARN: this can overwrite tile the agent is on, GOAL for example. diff --git a/src/xminigrid/envs/xland.py b/src/xminigrid/envs/xland.py index 4acac15..9092005 100644 --- a/src/xminigrid/envs/xland.py +++ b/src/xminigrid/envs/xland.py @@ -27,7 +27,6 @@ init_tiles=jnp.array(((TILES_REGISTRY[Tiles.EMPTY, Colors.EMPTY],))), ) -_empty_tile = TILES_REGISTRY[Tiles.EMPTY, Colors.EMPTY] _wall_tile = TILES_REGISTRY[Tiles.WALL, Colors.GREY] # colors for doors between rooms _allowed_colors = jnp.array( diff --git a/src/xminigrid/experimental/img_obs.py b/src/xminigrid/experimental/img_obs.py new file mode 100644 index 0000000..c8c28d3 --- /dev/null +++ b/src/xminigrid/experimental/img_obs.py @@ -0,0 +1,90 @@ +# jit-compatible RGB observations. Currently experimental! +# if it proves useful and necessary in the future, I will consider rewriting env.render in such style also +from __future__ import annotations + +import os + +import jax +import jax.numpy as jnp +import numpy as np + +from ..benchmarks import load_bz2_pickle, save_bz2_pickle +from ..core.constants import NUM_COLORS, NUM_LAYERS, TILES_REGISTRY +from ..rendering.rgb_render import render_tile +from ..wrappers import Wrapper + +CACHE_PATH = os.environ.get("XLAND_MINIGRID_CACHE", os.path.expanduser("~/.xland_minigrid")) + + +def build_cache(tiles: np.ndarray, tile_size: float = 32) -> tuple[np.ndarray, np.ndarray]: + cache = np.full((tiles.shape[0], tiles.shape[1], tile_size, tile_size, 3), dtype=np.uint8, fill_value=-1) + agent_cache = np.full((tiles.shape[0], tiles.shape[1], tile_size, tile_size, 3), dtype=np.uint8, fill_value=-1) + + for y in range(tiles.shape[0]): + for x in range(tiles.shape[1]): + # rendering tile + tile_img = render_tile( + tile=tuple(tiles[y, x]), + agent_direction=None, + highlight=False, + tile_size=int(tile_size), + ) + cache[y, x] = tile_img + + # rendering agent on top + tile_w_agent_img = render_tile( + tile=tuple(tiles[y, x]), + agent_direction=0, + highlight=False, + tile_size=int(tile_size), + ) + agent_cache[y, x] = tile_w_agent_img + + return cache, agent_cache + + +# building cache of pre-rendered tiles +TILE_SIZE = 32 + +cache_path = os.path.join(CACHE_PATH, "render_cache") +if not os.path.exists(cache_path): + TILE_CACHE, TILE_W_AGENT_CACHE = build_cache(np.asarray(TILES_REGISTRY), tile_size=TILE_SIZE) + TILE_CACHE = jnp.asarray(TILE_CACHE).reshape(-1, TILE_SIZE, TILE_SIZE, 3) + TILE_W_AGENT_CACHE = jnp.asarray(TILE_W_AGENT_CACHE).reshape(-1, TILE_SIZE, TILE_SIZE, 3) + + save_bz2_pickle({"tile_cache": TILE_CACHE, "tile_agent_cache": TILE_W_AGENT_CACHE}, cache_path) + +TILE_CACHE = load_bz2_pickle(cache_path)["tile_cache"] +TILE_W_AGENT_CACHE = load_bz2_pickle(cache_path)["tile_agent_cache"] + + +# rendering with cached tiles +def _render_obs(obs: jax.Array) -> jax.Array: + view_size = obs.shape[0] + + obs_flat_idxs = obs[:, :, 0] * NUM_COLORS + obs[:, :, 1] + # render all tiles + rendered_obs = jnp.take(TILE_CACHE, obs_flat_idxs, axis=0) + + # add agent tile + agent_tile = TILE_W_AGENT_CACHE[obs_flat_idxs[view_size - 1, view_size // 2]] + rendered_obs = rendered_obs.at[view_size - 1, view_size // 2].set(agent_tile) + # [view_size, view_size, tile_size, tile_size, 3] -> [view_size * tile_size, view_size * tile_size, 3] + rendered_obs = rendered_obs.transpose((0, 2, 1, 3, 4)).reshape(view_size * TILE_SIZE, view_size * TILE_SIZE, 3) + + return rendered_obs + + +class RGBImgObservationWrapper(Wrapper): + def observation_shape(self, params): + return params.view_size * TILE_SIZE, params.view_size * TILE_SIZE, NUM_LAYERS + + def reset(self, params, key): + timestep = self._env.reset(params, key) + timestep = timestep.replace(observation=_render_obs(timestep.observation)) + return timestep + + def step(self, params, timestep, action): + timestep = self._env.step(params, timestep, action) + timestep = timestep.replace(observation=_render_obs(timestep.observation)) + return timestep diff --git a/src/xminigrid/manual_control.py b/src/xminigrid/manual_control.py index 1dcd31f..1c58e76 100644 --- a/src/xminigrid/manual_control.py +++ b/src/xminigrid/manual_control.py @@ -14,9 +14,10 @@ class ManualControl: - def __init__(self, env: Environment[EnvParamsT, EnvCarryT], env_params: EnvParamsT): + def __init__(self, env: Environment[EnvParamsT, EnvCarryT], env_params: EnvParamsT, agent_view: bool = False): self.env = env self.env_params = env_params + self.agent_view = agent_view self._reset = jax.jit(self.env.reset) self._step = jax.jit(self.env.step) @@ -33,7 +34,10 @@ def __init__(self, env: Environment[EnvParamsT, EnvCarryT], env_params: EnvParam def render(self) -> None: assert self.timestep is not None - img = self.env.render(self.env_params, self.timestep) + if self.agent_view: + img = self.timestep.observation + else: + img = self.env.render(self.env_params, self.timestep) # [h, w, c] -> [w, h, c] img = np.transpose(img, axes=(1, 0, 2)) @@ -144,14 +148,20 @@ def close(self) -> None: parser.add_argument("--env-id", type=str, default="MiniGrid-Empty-5x5", choices=xminigrid.registered_environments()) parser.add_argument("--benchmark-id", type=str, default="trivial-1m", choices=xminigrid.registered_benchmarks()) parser.add_argument("--ruleset-id", type=int, default=0) + parser.add_argument("--agent-view", action="store_true") args = parser.parse_args() env, env_params = xminigrid.make(args.env_id) env = GymAutoResetWrapper(env) + if args.agent_view: + from xminigrid.experimental.img_obs import RGBImgObservationWrapper + + env = RGBImgObservationWrapper(env) + if "XLand" in args.env_id: bench = xminigrid.load_benchmark(args.benchmark_id) env_params = env_params.replace(ruleset=bench.get_ruleset(args.ruleset_id)) - control = ManualControl(env=env, env_params=env_params) + control = ManualControl(env=env, env_params=env_params, agent_view=args.agent_view) control.start() diff --git a/src/xminigrid/rendering/rgb_render.py b/src/xminigrid/rendering/rgb_render.py index 5c8ea81..0343a93 100644 --- a/src/xminigrid/rendering/rgb_render.py +++ b/src/xminigrid/rendering/rgb_render.py @@ -5,7 +5,7 @@ import numpy as np from ..core.constants import Colors, Tiles -from ..types import AgentState, GridState, IntOrArray +from ..types import AgentState, IntOrArray from .utils import ( downsample, fill_coords, @@ -18,6 +18,7 @@ ) COLORS_MAP = { + Colors.EMPTY: np.array((255, 255, 255)), # just a placeholder Colors.RED: np.array((255, 0, 0)), Colors.GREEN: np.array((0, 255, 0)), Colors.BLUE: np.array((0, 0, 255)), @@ -32,6 +33,16 @@ } +def _render_empty(img: np.ndarray, color: int): + fill_coords(img, point_in_rect(0.45, 0.55, 0.2, 0.65), COLORS_MAP[Colors.RED]) + fill_coords(img, point_in_rect(0.45, 0.55, 0.7, 0.85), COLORS_MAP[Colors.RED]) + + fill_coords(img, point_in_rect(0, 0.031, 0, 1), COLORS_MAP[Colors.RED]) + fill_coords(img, point_in_rect(0, 1, 0, 0.031), COLORS_MAP[Colors.RED]) + fill_coords(img, point_in_rect(1 - 0.031, 1, 0, 1), COLORS_MAP[Colors.RED]) + fill_coords(img, point_in_rect(0, 1, 1 - 0.031, 1), COLORS_MAP[Colors.RED]) + + def _render_floor(img: np.ndarray, color: int): # draw the grid lines (top and left edges) fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100)) @@ -165,7 +176,7 @@ def _render_player(img: np.ndarray, direction: int): Tiles.DOOR_LOCKED: _render_door_locked, Tiles.DOOR_CLOSED: _render_door_closed, Tiles.DOOR_OPEN: _render_door_open, - Tiles.EMPTY: lambda img, color: img, + Tiles.EMPTY: _render_empty, } @@ -196,7 +207,7 @@ def get_highlight_mask(grid: np.ndarray, agent: AgentState | None, view_size: in @functools.cache def render_tile( - tile: np.ndarray, agent_direction: int | None = None, highlight: bool = False, tile_size: int = 32, subdivs: int = 3 + tile: tuple, agent_direction: int | None = None, highlight: bool = False, tile_size: int = 32, subdivs: int = 3 ) -> np.ndarray: img = np.full((tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8, fill_value=255) # draw tile diff --git a/src/xminigrid/rendering/text_render.py b/src/xminigrid/rendering/text_render.py index 8f3887c..8b7c514 100644 --- a/src/xminigrid/rendering/text_render.py +++ b/src/xminigrid/rendering/text_render.py @@ -5,8 +5,6 @@ from ..types import AgentState, RuleSet COLOR_NAMES = { - Colors.END_OF_MAP: "red", - Colors.UNSEEN: "white", Colors.EMPTY: "white", Colors.RED: "red", Colors.GREEN: "green", @@ -22,8 +20,6 @@ } TILE_STR = { - Tiles.END_OF_MAP: "!", - Tiles.UNSEEN: "?", Tiles.EMPTY: " ", Tiles.FLOOR: ".", Tiles.WALL: "☰", diff --git a/src/xminigrid/wrappers.py b/src/xminigrid/wrappers.py index 15e9810..dcfa9f1 100644 --- a/src/xminigrid/wrappers.py +++ b/src/xminigrid/wrappers.py @@ -16,6 +16,12 @@ def __init__(self, env: Environment[EnvParamsT, EnvCarryT]): def default_params(self, **kwargs) -> EnvParamsT: return self._env.default_params(**kwargs) + def num_actions(self, params: EnvParamsT) -> int: + return self._env.num_actions(params) + + def observation_shape(self, params: EnvParamsT) -> tuple[int, int, int]: + return self._env.observation_shape(params) + def time_limit(self, params: EnvParamsT) -> int: return self._env.time_limit(params) From 26b72027093443766023c9c8ecc254737bb0dd46 Mon Sep 17 00:00:00 2001 From: Howuhh Date: Sat, 2 Mar 2024 10:16:33 +0300 Subject: [PATCH 02/14] revert pre-commit --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 308b33e..057a42b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,4 +14,4 @@ repos: rev: v1.1.350 hooks: - id: pyright -# args: [--project=pyproject.toml] \ No newline at end of file + args: [--project=pyproject.toml] \ No newline at end of file From 50d61b57fb9de5e928170ac5982ddff0a302eea0 Mon Sep 17 00:00:00 2001 From: Howuhh Date: Sat, 2 Mar 2024 10:19:39 +0300 Subject: [PATCH 03/14] fix types --- src/xminigrid/experimental/img_obs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/xminigrid/experimental/img_obs.py b/src/xminigrid/experimental/img_obs.py index c8c28d3..f4ee79b 100644 --- a/src/xminigrid/experimental/img_obs.py +++ b/src/xminigrid/experimental/img_obs.py @@ -16,9 +16,9 @@ CACHE_PATH = os.environ.get("XLAND_MINIGRID_CACHE", os.path.expanduser("~/.xland_minigrid")) -def build_cache(tiles: np.ndarray, tile_size: float = 32) -> tuple[np.ndarray, np.ndarray]: - cache = np.full((tiles.shape[0], tiles.shape[1], tile_size, tile_size, 3), dtype=np.uint8, fill_value=-1) - agent_cache = np.full((tiles.shape[0], tiles.shape[1], tile_size, tile_size, 3), dtype=np.uint8, fill_value=-1) +def build_cache(tiles: np.ndarray, tile_size: int = 32) -> tuple[np.ndarray, np.ndarray]: + cache = np.zeros((tiles.shape[0], tiles.shape[1], tile_size, tile_size, 3), dtype=np.uint8) + agent_cache = np.zeros((tiles.shape[0], tiles.shape[1], tile_size, tile_size, 3), dtype=np.uint8) for y in range(tiles.shape[0]): for x in range(tiles.shape[1]): From 2b0f87bd0713890b97559cf1d92bef18c0bf3d1a Mon Sep 17 00:00:00 2001 From: Howuhh Date: Sat, 2 Mar 2024 13:25:59 +0300 Subject: [PATCH 04/14] img obs benchmarking, fixed cache bug --- scripts/benchmark_xland.py | 24 ++++++++++++++++++++---- scripts/benchmark_xland_all.py | 16 ++++++++++++++-- scripts/generate_benchmarks.sh | 1 - src/xminigrid/__init__.py | 3 ++- src/xminigrid/experimental/img_obs.py | 2 ++ 5 files changed, 38 insertions(+), 8 deletions(-) diff --git a/scripts/benchmark_xland.py b/scripts/benchmark_xland.py index a7bd68b..5af3742 100644 --- a/scripts/benchmark_xland.py +++ b/scripts/benchmark_xland.py @@ -15,15 +15,29 @@ parser = argparse.ArgumentParser() parser.add_argument("--env-id", type=str, default="MiniGrid-Empty-16x16") parser.add_argument("--benchmark-id", type=str, default="Trivial") +parser.add_argument("--img-obs", action="store_true") parser.add_argument("--timesteps", type=int, default=1000) parser.add_argument("--num-envs", type=int, default=8192) parser.add_argument("--num-repeat", type=int, default=10, help="Number of timing repeats") parser.add_argument("--num-iter", type=int, default=1, help="Number of runs during one repeat (time is summed)") -def build_benchmark(env_id: str, num_envs: int, timesteps: int, benchmark_id: Optional[str] = None): +def build_benchmark( + env_id: str, + num_envs: int, + timesteps: int, + benchmark_id: Optional[str] = None, + img_obs: bool = False, +): env, env_params = xminigrid.make(env_id) env = GymAutoResetWrapper(env) + + # enable img observations if needed + if img_obs: + from xminigrid.experimental.img_obs import RGBImgObservationWrapper + + env = RGBImgObservationWrapper(env) + # choose XLand benchmark if needed if "XLand-MiniGrid" in env_id and benchmark_id is not None: ruleset = load_benchmark(benchmark_id).sample_ruleset(jax.random.PRNGKey(0)) @@ -73,13 +87,15 @@ def timeit_benchmark(args, benchmark_fn): print("Num devices for pmap:", num_devices) # building for single env benchmarking - benchmark_fn_single = build_benchmark(args.env_id, 1, args.timesteps, args.benchmark_id) + benchmark_fn_single = build_benchmark(args.env_id, 1, args.timesteps, args.benchmark_id, args.img_obs) benchmark_fn_single = jax.jit(benchmark_fn_single) # building vmap for vectorization benchmarking - benchmark_fn_vmap = build_benchmark(args.env_id, args.num_envs, args.timesteps, args.benchmark_id) + benchmark_fn_vmap = build_benchmark(args.env_id, args.num_envs, args.timesteps, args.benchmark_id, args.img_obs) benchmark_fn_vmap = jax.jit(benchmark_fn_vmap) # building pmap for multi-gpu benchmarking (each doing (num_envs / num_devices) vmaps) - benchmark_fn_pmap = build_benchmark(args.env_id, args.num_envs // num_devices, args.timesteps, args.benchmark_id) + benchmark_fn_pmap = build_benchmark( + args.env_id, args.num_envs // num_devices, args.timesteps, args.benchmark_id, args.img_obs + ) benchmark_fn_pmap = jax.pmap(benchmark_fn_pmap) key = jax.random.PRNGKey(0) diff --git a/scripts/benchmark_xland_all.py b/scripts/benchmark_xland_all.py index 3474820..cae74f9 100644 --- a/scripts/benchmark_xland_all.py +++ b/scripts/benchmark_xland_all.py @@ -18,14 +18,24 @@ parser = argparse.ArgumentParser() parser.add_argument("--benchmark-id", type=str, default="trivial-1m") +parser.add_argument("--img-obs", action="store_true") parser.add_argument("--timesteps", type=int, default=1000) parser.add_argument("--num-repeat", type=int, default=10, help="Number of timing repeats") parser.add_argument("--num-iter", type=int, default=1, help="Number of runs during one repeat (time is summed)") -def build_benchmark(env_id: str, num_envs: int, timesteps: int, benchmark_id: Optional[str] = None): +def build_benchmark( + env_id: str, num_envs: int, timesteps: int, benchmark_id: Optional[str] = None, img_obs: bool = False +): env, env_params = xminigrid.make(env_id) env = GymAutoResetWrapper(env) + + # enable img observations if needed + if img_obs: + from xminigrid.experimental.img_obs import RGBImgObservationWrapper + + env = RGBImgObservationWrapper(env) + # choose XLand benchmark if needed if "XLand-MiniGrid" in env_id and benchmark_id is not None: ruleset = load_benchmark(benchmark_id).sample_ruleset(jax.random.PRNGKey(0)) @@ -77,7 +87,9 @@ def timeit_benchmark(args, benchmark_fn): for env_id in tqdm(environments, desc="Envs.."): assert num_envs % num_devices == 0 # building pmap for multi-gpu benchmarking (each doing (num_envs / num_devices) vmaps) - benchmark_fn_pmap = build_benchmark(env_id, num_envs // num_devices, args.timesteps, args.benchmark_id) + benchmark_fn_pmap = build_benchmark( + env_id, num_envs // num_devices, args.timesteps, args.benchmark_id, args.img_obs + ) benchmark_fn_pmap = jax.pmap(benchmark_fn_pmap) # benchmarking diff --git a/scripts/generate_benchmarks.sh b/scripts/generate_benchmarks.sh index 9593efb..b2811a6 100644 --- a/scripts/generate_benchmarks.sh +++ b/scripts/generate_benchmarks.sh @@ -7,7 +7,6 @@ python scripts/ruleset_generator.py \ --total_rulesets=1_000_000 \ --save_path="trivial_1m" - # small python scripts/ruleset_generator.py \ --prune_chain \ diff --git a/src/xminigrid/__init__.py b/src/xminigrid/__init__.py index a67179d..a2f2af1 100644 --- a/src/xminigrid/__init__.py +++ b/src/xminigrid/__init__.py @@ -210,7 +210,8 @@ # BlockedUnlockPickUp register( - id="MiniGrid-BlockedUnlockPickUp", entry_point="xminigrid.envs.minigrid.blockedunlockpickup:BlockedUnlockPickUp" + id="MiniGrid-BlockedUnlockPickUp", + entry_point="xminigrid.envs.minigrid.blockedunlockpickup:BlockedUnlockPickUp", ) # DoorKey diff --git a/src/xminigrid/experimental/img_obs.py b/src/xminigrid/experimental/img_obs.py index f4ee79b..7ac3553 100644 --- a/src/xminigrid/experimental/img_obs.py +++ b/src/xminigrid/experimental/img_obs.py @@ -48,6 +48,8 @@ def build_cache(tiles: np.ndarray, tile_size: int = 32) -> tuple[np.ndarray, np. cache_path = os.path.join(CACHE_PATH, "render_cache") if not os.path.exists(cache_path): + os.makedirs(CACHE_PATH, exist_ok=True) + TILE_CACHE, TILE_W_AGENT_CACHE = build_cache(np.asarray(TILES_REGISTRY), tile_size=TILE_SIZE) TILE_CACHE = jnp.asarray(TILE_CACHE).reshape(-1, TILE_SIZE, TILE_SIZE, 3) TILE_W_AGENT_CACHE = jnp.asarray(TILE_W_AGENT_CACHE).reshape(-1, TILE_SIZE, TILE_SIZE, 3) From 92a54ffa0eb934eaad57ff25c3d758bebc34bef1 Mon Sep 17 00:00:00 2001 From: Howuhh Date: Mon, 4 Mar 2024 17:36:43 +0300 Subject: [PATCH 05/14] regenerated benchmarks --- scripts/generate_benchmarks.sh | 21 ++++++++++----------- src/xminigrid/benchmarks.py | 13 ++++++------- src/xminigrid/experimental/img_obs.py | 3 ++- src/xminigrid/manual_control.py | 5 ++++- 4 files changed, 22 insertions(+), 20 deletions(-) diff --git a/scripts/generate_benchmarks.sh b/scripts/generate_benchmarks.sh index b2811a6..47e9528 100644 --- a/scripts/generate_benchmarks.sh +++ b/scripts/generate_benchmarks.sh @@ -40,17 +40,16 @@ python scripts/ruleset_generator.py \ --total_rulesets=1_000_000 \ --save_path="high_1m" - -# medium + distractors -python scripts/ruleset_generator.py \ - --prune_chain \ - --prune_prob=0.8 \ - --chain_depth=2 \ - --sample_distractor_rules \ - --num_distractor_rules=4 \ - --num_distractor_objects=2 \ - --total_rulesets=1_000_000 \ - --save_path="medium_dist_1m" +## medium + distractors +#python scripts/ruleset_generator.py \ +# --prune_chain \ +# --prune_prob=0.8 \ +# --chain_depth=2 \ +# --sample_distractor_rules \ +# --num_distractor_rules=4 \ +# --num_distractor_objects=2 \ +# --total_rulesets=1_000_000 \ +# --save_path="medium_dist_1m" # medium 3M python scripts/ruleset_generator.py \ diff --git a/src/xminigrid/benchmarks.py b/src/xminigrid/benchmarks.py index 40b0641..2120d76 100644 --- a/src/xminigrid/benchmarks.py +++ b/src/xminigrid/benchmarks.py @@ -18,13 +18,12 @@ DATA_PATH = os.environ.get("XLAND_MINIGRID_DATA", os.path.expanduser("~/.xland_minigrid")) NAME2HFFILENAME = { - "trivial-1m": "trivial_1m", - "small-1m": "small_1m", - "small-dist-1m": "small_dist_1m", - "medium-1m": "medium_1m_v1", - "medium-3m": "medium_3m_v1", - "high-1m": "high_1m", - "high-3m": "high_3m", + "trivial-1m": "trivial_1m_v2", + "small-1m": "small_1m_v2", + "medium-1m": "medium_1m_v1_v2", + "medium-3m": "medium_3m_v1_v2", + "high-1m": "high_1m_v2", + "high-3m": "high_3m_v2", } diff --git a/src/xminigrid/experimental/img_obs.py b/src/xminigrid/experimental/img_obs.py index 7ac3553..a90f315 100644 --- a/src/xminigrid/experimental/img_obs.py +++ b/src/xminigrid/experimental/img_obs.py @@ -49,11 +49,12 @@ def build_cache(tiles: np.ndarray, tile_size: int = 32) -> tuple[np.ndarray, np. cache_path = os.path.join(CACHE_PATH, "render_cache") if not os.path.exists(cache_path): os.makedirs(CACHE_PATH, exist_ok=True) - + print("Building rendering cache, may take a while...") TILE_CACHE, TILE_W_AGENT_CACHE = build_cache(np.asarray(TILES_REGISTRY), tile_size=TILE_SIZE) TILE_CACHE = jnp.asarray(TILE_CACHE).reshape(-1, TILE_SIZE, TILE_SIZE, 3) TILE_W_AGENT_CACHE = jnp.asarray(TILE_W_AGENT_CACHE).reshape(-1, TILE_SIZE, TILE_SIZE, 3) + print("Done. Cache will be reused on consequent runs.") save_bz2_pickle({"tile_cache": TILE_CACHE, "tile_agent_cache": TILE_W_AGENT_CACHE}, cache_path) TILE_CACHE = load_bz2_pickle(cache_path)["tile_cache"] diff --git a/src/xminigrid/manual_control.py b/src/xminigrid/manual_control.py index 1c58e76..810ca67 100644 --- a/src/xminigrid/manual_control.py +++ b/src/xminigrid/manual_control.py @@ -7,10 +7,11 @@ from pygame.event import Event import xminigrid -from xminigrid.wrappers import GymAutoResetWrapper from .environment import Environment, EnvParamsT +from .rendering.text_render import print_ruleset from .types import EnvCarryT +from .wrappers import GymAutoResetWrapper class ManualControl: @@ -162,6 +163,8 @@ def close(self) -> None: if "XLand" in args.env_id: bench = xminigrid.load_benchmark(args.benchmark_id) env_params = env_params.replace(ruleset=bench.get_ruleset(args.ruleset_id)) + print_ruleset(env_params.ruleset) + print() control = ManualControl(env=env, env_params=env_params, agent_view=args.agent_view) control.start() From 055cc63879b3fbdb0d12ccad886af789e0664e0c Mon Sep 17 00:00:00 2001 From: Howuhh Date: Mon, 4 Mar 2024 22:26:54 +0300 Subject: [PATCH 06/14] add check --- src/xminigrid/manual_control.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/xminigrid/manual_control.py b/src/xminigrid/manual_control.py index 810ca67..a944869 100644 --- a/src/xminigrid/manual_control.py +++ b/src/xminigrid/manual_control.py @@ -163,6 +163,7 @@ def close(self) -> None: if "XLand" in args.env_id: bench = xminigrid.load_benchmark(args.benchmark_id) env_params = env_params.replace(ruleset=bench.get_ruleset(args.ruleset_id)) + assert hasattr(env_params, "ruleset") print_ruleset(env_params.ruleset) print() From d0a2316bacdd57c7d0ece3febab39da8d5f82372 Mon Sep 17 00:00:00 2001 From: Howuhh Date: Mon, 4 Mar 2024 22:37:09 +0300 Subject: [PATCH 07/14] remove check --- src/xminigrid/manual_control.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/xminigrid/manual_control.py b/src/xminigrid/manual_control.py index a944869..95185f5 100644 --- a/src/xminigrid/manual_control.py +++ b/src/xminigrid/manual_control.py @@ -9,6 +9,7 @@ import xminigrid from .environment import Environment, EnvParamsT +from .envs.xland import XLandEnvParams from .rendering.text_render import print_ruleset from .types import EnvCarryT from .wrappers import GymAutoResetWrapper @@ -162,9 +163,10 @@ def close(self) -> None: if "XLand" in args.env_id: bench = xminigrid.load_benchmark(args.benchmark_id) - env_params = env_params.replace(ruleset=bench.get_ruleset(args.ruleset_id)) - assert hasattr(env_params, "ruleset") - print_ruleset(env_params.ruleset) + ruleset = bench.get_ruleset(args.ruleset_id) + + env_params = env_params.replace(ruleset=ruleset) + print_ruleset(ruleset) print() control = ManualControl(env=env, env_params=env_params, agent_view=args.agent_view) From 845eac9332685f4279bc8f6bd294875689f4c020 Mon Sep 17 00:00:00 2001 From: Howuhh Date: Tue, 5 Mar 2024 19:11:37 +0300 Subject: [PATCH 08/14] manual control video save --- pyproject.toml | 4 ++-- src/xminigrid/manual_control.py | 42 +++++++++++++++++++++++++++++---- 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d0f1db0..99dceb9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,8 @@ dependencies = [ "flax>=0.8.0", "rich>=13.4.2", "chex>=0.1.85", + "imageio>=2.31.2", + "imageio-ffmpeg>=0.4.9", ] [project.optional-dependencies] @@ -49,8 +51,6 @@ dev = [ baselines = [ "matplotlib>=3.7.2", - "imageio>=2.31.2", - "imageio-ffmpeg>=0.4.9", "wandb>=0.15.10", "pyrallis>=0.3.1", "distrax>=0.1.4", diff --git a/src/xminigrid/manual_control.py b/src/xminigrid/manual_control.py index 95185f5..5ea4068 100644 --- a/src/xminigrid/manual_control.py +++ b/src/xminigrid/manual_control.py @@ -1,5 +1,9 @@ +from __future__ import annotations + import argparse +import os +import imageio import jax import numpy as np import pygame @@ -9,17 +13,26 @@ import xminigrid from .environment import Environment, EnvParamsT -from .envs.xland import XLandEnvParams from .rendering.text_render import print_ruleset from .types import EnvCarryT -from .wrappers import GymAutoResetWrapper class ManualControl: - def __init__(self, env: Environment[EnvParamsT, EnvCarryT], env_params: EnvParamsT, agent_view: bool = False): + def __init__( + self, + env: Environment[EnvParamsT, EnvCarryT], + env_params: EnvParamsT, + agent_view: bool = False, + save_video: bool = False, + video_path: str | None = None, + ): self.env = env self.env_params = env_params self.agent_view = agent_view + self.save_video = save_video + self.video_path = video_path + if self.save_video: + self.frames = [] self._reset = jax.jit(self.env.reset) self._step = jax.jit(self.env.step) @@ -40,6 +53,10 @@ def render(self) -> None: img = self.timestep.observation else: img = self.env.render(self.env_params, self.timestep) + + if self.save_video: + self.frames.append(img) + # [h, w, c] -> [w, h, c] img = np.transpose(img, axes=(1, 0, 2)) @@ -103,6 +120,9 @@ def step(self, action: int) -> None: ) self.render() + if self.timestep.last(): + self.reset() + def reset(self) -> None: print("Reset!") self._key, reset_key = jax.random.split(self._key) @@ -144,6 +164,11 @@ def close(self) -> None: if self.window: pygame.quit() + if self.save_video: + assert self.video_path is not None + save_path = os.path.join(self.video_path, "manual_control_rollout.mp4") + imageio.mimsave(save_path, self.frames, fps=8, format="mp4") + if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -151,10 +176,11 @@ def close(self) -> None: parser.add_argument("--benchmark-id", type=str, default="trivial-1m", choices=xminigrid.registered_benchmarks()) parser.add_argument("--ruleset-id", type=int, default=0) parser.add_argument("--agent-view", action="store_true") + parser.add_argument("--save-video", action="store_true") + parser.add_argument("--video-path", type=str, default=".") args = parser.parse_args() env, env_params = xminigrid.make(args.env_id) - env = GymAutoResetWrapper(env) if args.agent_view: from xminigrid.experimental.img_obs import RGBImgObservationWrapper @@ -169,5 +195,11 @@ def close(self) -> None: print_ruleset(ruleset) print() - control = ManualControl(env=env, env_params=env_params, agent_view=args.agent_view) + control = ManualControl( + env=env, + env_params=env_params, + agent_view=args.agent_view, + save_video=args.save_video, + video_path=args.video_path, + ) control.start() From 27c5c726ef6a24f320c8418e2046918e7264ca6f Mon Sep 17 00:00:00 2001 From: Howuhh Date: Tue, 5 Mar 2024 19:37:48 +0300 Subject: [PATCH 09/14] adapted training scripts for img obs --- training/nn.py | 39 ++++++++++++++++++++++++----------- training/train_meta_task.py | 8 +++++++ training/train_single_task.py | 8 +++++++ 3 files changed, 43 insertions(+), 12 deletions(-) diff --git a/training/nn.py b/training/nn.py index a067df6..99ebac1 100644 --- a/training/nn.py +++ b/training/nn.py @@ -100,23 +100,38 @@ class ActorCriticRNN(nn.Module): rnn_hidden_dim: int = 64 rnn_num_layers: int = 1 head_hidden_dim: int = 64 + img_obs: bool = False @nn.compact def __call__(self, inputs: ActorCriticInput, hidden: jax.Array) -> tuple[distrax.Categorical, jax.Array, jax.Array]: B, S = inputs["observation"].shape[:2] # encoder from https://github.com/lcswillems/rl-starter-files/blob/master/model.py - img_encoder = nn.Sequential( - [ - nn.Conv(16, (2, 2), padding="VALID", kernel_init=orthogonal(math.sqrt(2))), - nn.relu, - # use this only for image sizes >= 7 - # MaxPool2d((2, 2)), - nn.Conv(32, (2, 2), padding="VALID", kernel_init=orthogonal(math.sqrt(2))), - nn.relu, - nn.Conv(64, (2, 2), padding="VALID", kernel_init=orthogonal(math.sqrt(2))), - nn.relu, - ] - ) + if self.img_obs: + # slight modification of NatureDQN CNN + img_encoder = nn.Sequential( + [ + nn.Conv(32, (8, 8), strides=4, padding="VALID", kernel_init=orthogonal(math.sqrt(2))), + nn.relu, + nn.Conv(64, (4, 4), strides=3, padding="VALID", kernel_init=orthogonal(math.sqrt(2))), + nn.relu, + nn.Conv(64, (3, 3), strides=2, padding="VALID", kernel_init=orthogonal(math.sqrt(2))), + nn.relu, + nn.Conv(64, (2, 2), strides=1, padding="VALID", kernel_init=orthogonal(math.sqrt(2))), + ] + ) + else: + img_encoder = nn.Sequential( + [ + nn.Conv(16, (2, 2), padding="VALID", kernel_init=orthogonal(math.sqrt(2))), + nn.relu, + # use this only for image sizes >= 7 + # MaxPool2d((2, 2)), + nn.Conv(32, (2, 2), padding="VALID", kernel_init=orthogonal(math.sqrt(2))), + nn.relu, + nn.Conv(64, (2, 2), padding="VALID", kernel_init=orthogonal(math.sqrt(2))), + nn.relu, + ] + ) action_encoder = nn.Embed(self.num_actions, self.action_emb_dim) rnn_core = BatchedRNNModel(self.rnn_hidden_dim, self.rnn_num_layers) diff --git a/training/train_meta_task.py b/training/train_meta_task.py index 2c0f2db..e491fcc 100644 --- a/training/train_meta_task.py +++ b/training/train_meta_task.py @@ -36,6 +36,7 @@ class TrainConfig: name: str = "meta-task-ppo" env_id: str = "XLand-MiniGrid-R1-9x9" benchmark_id: str = "trivial-1m" + img_obs: bool = False # agent action_emb_dim: int = 16 rnn_hidden_dim: int = 1024 @@ -90,6 +91,12 @@ def linear_schedule(count): env, env_params = xminigrid.make(config.env_id) env = GymAutoResetWrapper(env) + # enabling image observations if needed + if config.img_obs: + from xminigrid.experimental.img_obs import RGBImgObservationWrapper + + env = RGBImgObservationWrapper(env) + # loading benchmark benchmark = xminigrid.load_benchmark(config.benchmark_id) @@ -103,6 +110,7 @@ def linear_schedule(count): rnn_hidden_dim=config.rnn_hidden_dim, rnn_num_layers=config.rnn_num_layers, head_hidden_dim=config.head_hidden_dim, + img_obs=config.img_obs, ) # [batch_size, seq_len, ...] init_obs = { diff --git a/training/train_single_task.py b/training/train_single_task.py index 0c63f49..e4064a9 100644 --- a/training/train_single_task.py +++ b/training/train_single_task.py @@ -29,6 +29,7 @@ class TrainConfig: group: str = "default" name: str = "single-task-ppo" env_id: str = "MiniGrid-Empty-6x6" + img_obs: bool = False # agent action_emb_dim: int = 16 rnn_hidden_dim: int = 1024 @@ -74,6 +75,12 @@ def linear_schedule(count): env, env_params = xminigrid.make(config.env_id) env = GymAutoResetWrapper(env) + # enabling image observations if needed + if config.img_obs: + from xminigrid.experimental.img_obs import RGBImgObservationWrapper + + env = RGBImgObservationWrapper(env) + # setup training state rng = jax.random.PRNGKey(config.seed) rng, _rng = jax.random.split(rng) @@ -84,6 +91,7 @@ def linear_schedule(count): rnn_hidden_dim=config.rnn_hidden_dim, rnn_num_layers=config.rnn_num_layers, head_hidden_dim=config.head_hidden_dim, + img_obs=config.img_obs, ) # [batch_size, seq_len, ...] init_obs = { From 9b5d6df9cf4964e8bf395c4795414f9d96c0964b Mon Sep 17 00:00:00 2001 From: Howuhh Date: Tue, 5 Mar 2024 20:02:04 +0300 Subject: [PATCH 10/14] fix imageio --- src/xminigrid/experimental/img_obs.py | 2 +- src/xminigrid/manual_control.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/xminigrid/experimental/img_obs.py b/src/xminigrid/experimental/img_obs.py index a90f315..759bf0a 100644 --- a/src/xminigrid/experimental/img_obs.py +++ b/src/xminigrid/experimental/img_obs.py @@ -54,7 +54,7 @@ def build_cache(tiles: np.ndarray, tile_size: int = 32) -> tuple[np.ndarray, np. TILE_CACHE = jnp.asarray(TILE_CACHE).reshape(-1, TILE_SIZE, TILE_SIZE, 3) TILE_W_AGENT_CACHE = jnp.asarray(TILE_W_AGENT_CACHE).reshape(-1, TILE_SIZE, TILE_SIZE, 3) - print("Done. Cache will be reused on consequent runs.") + print(f"Done. Cache is saved to {cache_path} and will be reused on consequent runs.") save_bz2_pickle({"tile_cache": TILE_CACHE, "tile_agent_cache": TILE_W_AGENT_CACHE}, cache_path) TILE_CACHE = load_bz2_pickle(cache_path)["tile_cache"] diff --git a/src/xminigrid/manual_control.py b/src/xminigrid/manual_control.py index 5ea4068..d4174f4 100644 --- a/src/xminigrid/manual_control.py +++ b/src/xminigrid/manual_control.py @@ -3,7 +3,7 @@ import argparse import os -import imageio +import imageio.v3 as iio import jax import numpy as np import pygame @@ -167,7 +167,7 @@ def close(self) -> None: if self.save_video: assert self.video_path is not None save_path = os.path.join(self.video_path, "manual_control_rollout.mp4") - imageio.mimsave(save_path, self.frames, fps=8, format="mp4") + iio.imwrite(save_path, self.frames, format_hint=".mp4", fps=8) if __name__ == "__main__": From c783db6de0be61d1e3dd32b01ef06609dde563bc Mon Sep 17 00:00:00 2001 From: Howuhh Date: Tue, 5 Mar 2024 22:47:56 +0300 Subject: [PATCH 11/14] fix bug --- src/xminigrid/benchmarks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/xminigrid/benchmarks.py b/src/xminigrid/benchmarks.py index 2120d76..2b23790 100644 --- a/src/xminigrid/benchmarks.py +++ b/src/xminigrid/benchmarks.py @@ -20,8 +20,8 @@ NAME2HFFILENAME = { "trivial-1m": "trivial_1m_v2", "small-1m": "small_1m_v2", - "medium-1m": "medium_1m_v1_v2", - "medium-3m": "medium_3m_v1_v2", + "medium-1m": "medium_1m_v2", + "medium-3m": "medium_3m_v2", "high-1m": "high_1m_v2", "high-3m": "high_3m_v2", } From 21b486fca9e7cf2a148ef59daab752f89b29fc4a Mon Sep 17 00:00:00 2001 From: Howuhh Date: Wed, 6 Mar 2024 14:15:13 +0300 Subject: [PATCH 12/14] add force cache reload, fix bug in wrapper --- src/xminigrid/experimental/img_obs.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/xminigrid/experimental/img_obs.py b/src/xminigrid/experimental/img_obs.py index 759bf0a..cc854b4 100644 --- a/src/xminigrid/experimental/img_obs.py +++ b/src/xminigrid/experimental/img_obs.py @@ -14,6 +14,7 @@ from ..wrappers import Wrapper CACHE_PATH = os.environ.get("XLAND_MINIGRID_CACHE", os.path.expanduser("~/.xland_minigrid")) +FORCE_RELOAD = os.environ.get("XLAND_MINIGRID_RELOAD_CACHE", False) def build_cache(tiles: np.ndarray, tile_size: int = 32) -> tuple[np.ndarray, np.ndarray]: @@ -47,7 +48,8 @@ def build_cache(tiles: np.ndarray, tile_size: int = 32) -> tuple[np.ndarray, np. TILE_SIZE = 32 cache_path = os.path.join(CACHE_PATH, "render_cache") -if not os.path.exists(cache_path): + +if not os.path.exists(cache_path) or FORCE_RELOAD: os.makedirs(CACHE_PATH, exist_ok=True) print("Building rendering cache, may take a while...") TILE_CACHE, TILE_W_AGENT_CACHE = build_cache(np.asarray(TILES_REGISTRY), tile_size=TILE_SIZE) @@ -80,7 +82,7 @@ def _render_obs(obs: jax.Array) -> jax.Array: class RGBImgObservationWrapper(Wrapper): def observation_shape(self, params): - return params.view_size * TILE_SIZE, params.view_size * TILE_SIZE, NUM_LAYERS + return params.view_size * TILE_SIZE, params.view_size * TILE_SIZE, 3 def reset(self, params, key): timestep = self._env.reset(params, key) From dc26721ec7b30dd9c4a4a7c6ee3fb3c7fcf69d32 Mon Sep 17 00:00:00 2001 From: Howuhh Date: Wed, 6 Mar 2024 21:03:36 +0300 Subject: [PATCH 13/14] fix bug with rules and empty tiles --- src/xminigrid/core/rules.py | 25 +++++++++++++------------ src/xminigrid/manual_control.py | 18 ++++++++++++++++-- 2 files changed, 29 insertions(+), 14 deletions(-) diff --git a/src/xminigrid/core/rules.py b/src/xminigrid/core/rules.py index 5f903b0..83a6b89 100644 --- a/src/xminigrid/core/rules.py +++ b/src/xminigrid/core/rules.py @@ -202,18 +202,18 @@ def __call__(self, grid, agent, action, position): tile = grid[position[0], position[1]] def _rule_fn(grid): - empty_tile = TILES_REGISTRY[Tiles.EMPTY, Colors.EMPTY] + floor_tile = TILES_REGISTRY[Tiles.FLOOR, Colors.BLACK] y, x = position up, _, down, _ = get_neighbouring_tiles(grid, y, x) grid = jax.lax.select( equal(tile, self.tile_b) & equal(down, self.tile_a), - grid.at[y + 1, x].set(self.prod_tile).at[y, x].set(empty_tile), + grid.at[y + 1, x].set(self.prod_tile).at[y, x].set(floor_tile), grid, ) grid = jax.lax.select( equal(tile, self.tile_a) & equal(up, self.tile_b), - grid.at[y - 1, x].set(self.prod_tile).at[y, x].set(empty_tile), + grid.at[y - 1, x].set(self.prod_tile).at[y, x].set(floor_tile), grid, ) return grid @@ -243,18 +243,19 @@ def __call__(self, grid, agent, action, position): tile = grid[position[0], position[1]] def _rule_fn(grid): - empty_tile = TILES_REGISTRY[Tiles.EMPTY, Colors.EMPTY] + floor_tile = TILES_REGISTRY[Tiles.FLOOR, Colors.BLACK] + y, x = position _, right, _, left = get_neighbouring_tiles(grid, y, x) grid = jax.lax.select( equal(tile, self.tile_b) & equal(left, self.tile_a), - grid.at[y, x - 1].set(self.prod_tile).at[y, x].set(empty_tile), + grid.at[y, x - 1].set(self.prod_tile).at[y, x].set(floor_tile), grid, ) grid = jax.lax.select( equal(tile, self.tile_a) & equal(right, self.tile_b), - grid.at[y, x + 1].set(self.prod_tile).at[y, x].set(empty_tile), + grid.at[y, x + 1].set(self.prod_tile).at[y, x].set(floor_tile), grid, ) return grid @@ -284,18 +285,18 @@ def __call__(self, grid, agent, action, position): tile = grid[position[0], position[1]] def _rule_fn(grid): - empty_tile = TILES_REGISTRY[Tiles.EMPTY, Colors.EMPTY] + floor_tile = TILES_REGISTRY[Tiles.FLOOR, Colors.BLACK] y, x = position up, _, down, _ = get_neighbouring_tiles(grid, y, x) grid = jax.lax.select( equal(tile, self.tile_b) & equal(up, self.tile_a), - grid.at[y - 1, x].set(self.prod_tile).at[y, x].set(empty_tile), + grid.at[y - 1, x].set(self.prod_tile).at[y, x].set(floor_tile), grid, ) grid = jax.lax.select( equal(tile, self.tile_a) & equal(down, self.tile_b), - grid.at[y + 1, x].set(self.prod_tile).at[y, x].set(empty_tile), + grid.at[y + 1, x].set(self.prod_tile).at[y, x].set(floor_tile), grid, ) return grid @@ -325,18 +326,18 @@ def __call__(self, grid, agent, action, position): tile = grid[position[0], position[1]] def _rule_fn(grid): - empty_tile = TILES_REGISTRY[Tiles.EMPTY, Colors.EMPTY] + floor_tile = TILES_REGISTRY[Tiles.FLOOR, Colors.BLACK] y, x = position _, right, _, left = get_neighbouring_tiles(grid, y, x) grid = jax.lax.select( equal(tile, self.tile_b) & equal(right, self.tile_a), - grid.at[y, x + 1].set(self.prod_tile).at[y, x].set(empty_tile), + grid.at[y, x + 1].set(self.prod_tile).at[y, x].set(floor_tile), grid, ) grid = jax.lax.select( equal(tile, self.tile_a) & equal(left, self.tile_b), - grid.at[y, x - 1].set(self.prod_tile).at[y, x].set(empty_tile), + grid.at[y, x - 1].set(self.prod_tile).at[y, x].set(floor_tile), grid, ) return grid diff --git a/src/xminigrid/manual_control.py b/src/xminigrid/manual_control.py index d4174f4..2ea07c4 100644 --- a/src/xminigrid/manual_control.py +++ b/src/xminigrid/manual_control.py @@ -25,12 +25,17 @@ def __init__( agent_view: bool = False, save_video: bool = False, video_path: str | None = None, + video_format: str = ".mp4", + video_fps: int = 8, ): self.env = env self.env_params = env_params self.agent_view = agent_view self.save_video = save_video self.video_path = video_path + self.video_format = video_format + self.video_fps = video_fps + if self.save_video: self.frames = [] @@ -166,8 +171,13 @@ def close(self) -> None: if self.save_video: assert self.video_path is not None - save_path = os.path.join(self.video_path, "manual_control_rollout.mp4") - iio.imwrite(save_path, self.frames, format_hint=".mp4", fps=8) + save_path = os.path.join(self.video_path, f"manual_control_rollout{self.video_format}") + if self.video_format == ".mp4": + iio.imwrite(save_path, self.frames, format_hint=".mp4", fps=self.video_fps) + elif self.video_format == ".gif": + iio.imwrite(save_path, self.frames, format_hint=".gif", duration=(1000 * 1 / self.video_fps), loop=10) + else: + raise RuntimeError("Unknown video format! Should be one of ('.mp4', '.gif')") if __name__ == "__main__": @@ -178,6 +188,8 @@ def close(self) -> None: parser.add_argument("--agent-view", action="store_true") parser.add_argument("--save-video", action="store_true") parser.add_argument("--video-path", type=str, default=".") + parser.add_argument("--video-format", type=str, default=".mp4", choices=(".mp4", ".gif")) + parser.add_argument("--video-fps", type=int, default=8) args = parser.parse_args() env, env_params = xminigrid.make(args.env_id) @@ -201,5 +213,7 @@ def close(self) -> None: agent_view=args.agent_view, save_video=args.save_video, video_path=args.video_path, + video_format=args.video_format, + video_fps=args.video_fps, ) control.start() From 251301facdffcf5de4590bffc2f4f5431c16370d Mon Sep 17 00:00:00 2001 From: Howuhh Date: Sun, 24 Mar 2024 13:50:09 +0300 Subject: [PATCH 14/14] refine readme, cnn arch --- README.md | 4 ++++ src/xminigrid/manual_control.py | 7 +++++-- training/nn.py | 9 ++++----- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index cddc57a..6d67451 100644 --- a/README.md +++ b/README.md @@ -93,6 +93,7 @@ On the high level, current API combines [dm_env](https://github.com/google-deepm import jax import xminigrid from xminigrid.wrappers import GymAutoResetWrapper +from xminigrid.experimental.img_obs import RGBImgObservationWrapper key = jax.random.PRNGKey(0) reset_key, ruleset_key = jax.random.split(key) @@ -109,6 +110,9 @@ env_params = env_params.replace(ruleset=ruleset) # auto-reset wrapper env = GymAutoResetWrapper(env) +# render obs as rgb images if needed (warn: this will affect speed greatly) +env = RGBImgObservationWrapper(env) + # fully jit-compatible step and reset methods timestep = jax.jit(env.reset)(env_params, reset_key) timestep = jax.jit(env.step)(env_params, timestep, action=0) diff --git a/src/xminigrid/manual_control.py b/src/xminigrid/manual_control.py index 2ea07c4..43f9e7b 100644 --- a/src/xminigrid/manual_control.py +++ b/src/xminigrid/manual_control.py @@ -175,7 +175,10 @@ def close(self) -> None: if self.video_format == ".mp4": iio.imwrite(save_path, self.frames, format_hint=".mp4", fps=self.video_fps) elif self.video_format == ".gif": - iio.imwrite(save_path, self.frames, format_hint=".gif", duration=(1000 * 1 / self.video_fps), loop=10) + iio.imwrite( + save_path, self.frames[:-1], format_hint=".gif", duration=(1000 * 1 / self.video_fps), loop=10 + ) + # iio.imwrite(save_path, self.frames, format_hint=".gif", duration=(1000 * 1 / self.video_fps), loop=10) else: raise RuntimeError("Unknown video format! Should be one of ('.mp4', '.gif')") @@ -189,7 +192,7 @@ def close(self) -> None: parser.add_argument("--save-video", action="store_true") parser.add_argument("--video-path", type=str, default=".") parser.add_argument("--video-format", type=str, default=".mp4", choices=(".mp4", ".gif")) - parser.add_argument("--video-fps", type=int, default=8) + parser.add_argument("--video-fps", type=int, default=5) args = parser.parse_args() env, env_params = xminigrid.make(args.env_id) diff --git a/training/nn.py b/training/nn.py index 99ebac1..759c579 100644 --- a/training/nn.py +++ b/training/nn.py @@ -107,16 +107,15 @@ def __call__(self, inputs: ActorCriticInput, hidden: jax.Array) -> tuple[distrax B, S = inputs["observation"].shape[:2] # encoder from https://github.com/lcswillems/rl-starter-files/blob/master/model.py if self.img_obs: - # slight modification of NatureDQN CNN img_encoder = nn.Sequential( [ - nn.Conv(32, (8, 8), strides=4, padding="VALID", kernel_init=orthogonal(math.sqrt(2))), + nn.Conv(16, (3, 3), strides=2, padding="VALID", kernel_init=orthogonal(math.sqrt(2))), nn.relu, - nn.Conv(64, (4, 4), strides=3, padding="VALID", kernel_init=orthogonal(math.sqrt(2))), + nn.Conv(32, (3, 3), strides=2, padding="VALID", kernel_init=orthogonal(math.sqrt(2))), nn.relu, - nn.Conv(64, (3, 3), strides=2, padding="VALID", kernel_init=orthogonal(math.sqrt(2))), + nn.Conv(32, (3, 3), strides=2, padding="VALID", kernel_init=orthogonal(math.sqrt(2))), nn.relu, - nn.Conv(64, (2, 2), strides=1, padding="VALID", kernel_init=orthogonal(math.sqrt(2))), + nn.Conv(32, (3, 3), strides=2, padding="VALID", kernel_init=orthogonal(math.sqrt(2))), ] ) else: