diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 308b33e..057a42b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,4 +14,4 @@ repos: rev: v1.1.350 hooks: - id: pyright -# args: [--project=pyproject.toml] \ No newline at end of file + args: [--project=pyproject.toml] \ No newline at end of file diff --git a/README.md b/README.md index cddc57a..6d67451 100644 --- a/README.md +++ b/README.md @@ -93,6 +93,7 @@ On the high level, current API combines [dm_env](https://github.com/google-deepm import jax import xminigrid from xminigrid.wrappers import GymAutoResetWrapper +from xminigrid.experimental.img_obs import RGBImgObservationWrapper key = jax.random.PRNGKey(0) reset_key, ruleset_key = jax.random.split(key) @@ -109,6 +110,9 @@ env_params = env_params.replace(ruleset=ruleset) # auto-reset wrapper env = GymAutoResetWrapper(env) +# render obs as rgb images if needed (warn: this will affect speed greatly) +env = RGBImgObservationWrapper(env) + # fully jit-compatible step and reset methods timestep = jax.jit(env.reset)(env_params, reset_key) timestep = jax.jit(env.step)(env_params, timestep, action=0) diff --git a/pyproject.toml b/pyproject.toml index d0f1db0..99dceb9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,8 @@ dependencies = [ "flax>=0.8.0", "rich>=13.4.2", "chex>=0.1.85", + "imageio>=2.31.2", + "imageio-ffmpeg>=0.4.9", ] [project.optional-dependencies] @@ -49,8 +51,6 @@ dev = [ baselines = [ "matplotlib>=3.7.2", - "imageio>=2.31.2", - "imageio-ffmpeg>=0.4.9", "wandb>=0.15.10", "pyrallis>=0.3.1", "distrax>=0.1.4", diff --git a/scripts/benchmark_xland.py b/scripts/benchmark_xland.py index a7bd68b..5af3742 100644 --- a/scripts/benchmark_xland.py +++ b/scripts/benchmark_xland.py @@ -15,15 +15,29 @@ parser = argparse.ArgumentParser() parser.add_argument("--env-id", type=str, default="MiniGrid-Empty-16x16") parser.add_argument("--benchmark-id", type=str, default="Trivial") +parser.add_argument("--img-obs", action="store_true") parser.add_argument("--timesteps", type=int, default=1000) parser.add_argument("--num-envs", type=int, default=8192) parser.add_argument("--num-repeat", type=int, default=10, help="Number of timing repeats") parser.add_argument("--num-iter", type=int, default=1, help="Number of runs during one repeat (time is summed)") -def build_benchmark(env_id: str, num_envs: int, timesteps: int, benchmark_id: Optional[str] = None): +def build_benchmark( + env_id: str, + num_envs: int, + timesteps: int, + benchmark_id: Optional[str] = None, + img_obs: bool = False, +): env, env_params = xminigrid.make(env_id) env = GymAutoResetWrapper(env) + + # enable img observations if needed + if img_obs: + from xminigrid.experimental.img_obs import RGBImgObservationWrapper + + env = RGBImgObservationWrapper(env) + # choose XLand benchmark if needed if "XLand-MiniGrid" in env_id and benchmark_id is not None: ruleset = load_benchmark(benchmark_id).sample_ruleset(jax.random.PRNGKey(0)) @@ -73,13 +87,15 @@ def timeit_benchmark(args, benchmark_fn): print("Num devices for pmap:", num_devices) # building for single env benchmarking - benchmark_fn_single = build_benchmark(args.env_id, 1, args.timesteps, args.benchmark_id) + benchmark_fn_single = build_benchmark(args.env_id, 1, args.timesteps, args.benchmark_id, args.img_obs) benchmark_fn_single = jax.jit(benchmark_fn_single) # building vmap for vectorization benchmarking - benchmark_fn_vmap = build_benchmark(args.env_id, args.num_envs, args.timesteps, args.benchmark_id) + benchmark_fn_vmap = build_benchmark(args.env_id, args.num_envs, args.timesteps, args.benchmark_id, args.img_obs) benchmark_fn_vmap = jax.jit(benchmark_fn_vmap) # building pmap for multi-gpu benchmarking (each doing (num_envs / num_devices) vmaps) - benchmark_fn_pmap = build_benchmark(args.env_id, args.num_envs // num_devices, args.timesteps, args.benchmark_id) + benchmark_fn_pmap = build_benchmark( + args.env_id, args.num_envs // num_devices, args.timesteps, args.benchmark_id, args.img_obs + ) benchmark_fn_pmap = jax.pmap(benchmark_fn_pmap) key = jax.random.PRNGKey(0) diff --git a/scripts/benchmark_xland_all.py b/scripts/benchmark_xland_all.py index 3474820..cae74f9 100644 --- a/scripts/benchmark_xland_all.py +++ b/scripts/benchmark_xland_all.py @@ -18,14 +18,24 @@ parser = argparse.ArgumentParser() parser.add_argument("--benchmark-id", type=str, default="trivial-1m") +parser.add_argument("--img-obs", action="store_true") parser.add_argument("--timesteps", type=int, default=1000) parser.add_argument("--num-repeat", type=int, default=10, help="Number of timing repeats") parser.add_argument("--num-iter", type=int, default=1, help="Number of runs during one repeat (time is summed)") -def build_benchmark(env_id: str, num_envs: int, timesteps: int, benchmark_id: Optional[str] = None): +def build_benchmark( + env_id: str, num_envs: int, timesteps: int, benchmark_id: Optional[str] = None, img_obs: bool = False +): env, env_params = xminigrid.make(env_id) env = GymAutoResetWrapper(env) + + # enable img observations if needed + if img_obs: + from xminigrid.experimental.img_obs import RGBImgObservationWrapper + + env = RGBImgObservationWrapper(env) + # choose XLand benchmark if needed if "XLand-MiniGrid" in env_id and benchmark_id is not None: ruleset = load_benchmark(benchmark_id).sample_ruleset(jax.random.PRNGKey(0)) @@ -77,7 +87,9 @@ def timeit_benchmark(args, benchmark_fn): for env_id in tqdm(environments, desc="Envs.."): assert num_envs % num_devices == 0 # building pmap for multi-gpu benchmarking (each doing (num_envs / num_devices) vmaps) - benchmark_fn_pmap = build_benchmark(env_id, num_envs // num_devices, args.timesteps, args.benchmark_id) + benchmark_fn_pmap = build_benchmark( + env_id, num_envs // num_devices, args.timesteps, args.benchmark_id, args.img_obs + ) benchmark_fn_pmap = jax.pmap(benchmark_fn_pmap) # benchmarking diff --git a/scripts/generate_benchmarks.sh b/scripts/generate_benchmarks.sh index 9593efb..47e9528 100644 --- a/scripts/generate_benchmarks.sh +++ b/scripts/generate_benchmarks.sh @@ -7,7 +7,6 @@ python scripts/ruleset_generator.py \ --total_rulesets=1_000_000 \ --save_path="trivial_1m" - # small python scripts/ruleset_generator.py \ --prune_chain \ @@ -41,17 +40,16 @@ python scripts/ruleset_generator.py \ --total_rulesets=1_000_000 \ --save_path="high_1m" - -# medium + distractors -python scripts/ruleset_generator.py \ - --prune_chain \ - --prune_prob=0.8 \ - --chain_depth=2 \ - --sample_distractor_rules \ - --num_distractor_rules=4 \ - --num_distractor_objects=2 \ - --total_rulesets=1_000_000 \ - --save_path="medium_dist_1m" +## medium + distractors +#python scripts/ruleset_generator.py \ +# --prune_chain \ +# --prune_prob=0.8 \ +# --chain_depth=2 \ +# --sample_distractor_rules \ +# --num_distractor_rules=4 \ +# --num_distractor_objects=2 \ +# --total_rulesets=1_000_000 \ +# --save_path="medium_dist_1m" # medium 3M python scripts/ruleset_generator.py \ diff --git a/src/xminigrid/__init__.py b/src/xminigrid/__init__.py index 2e0abaa..a2f2af1 100644 --- a/src/xminigrid/__init__.py +++ b/src/xminigrid/__init__.py @@ -2,7 +2,7 @@ from .registration import make, register, registered_environments # TODO: add __all__ -__version__ = "0.6.0" +__version__ = "0.7.0" # ---------- XLand-MiniGrid environments ---------- @@ -210,7 +210,8 @@ # BlockedUnlockPickUp register( - id="MiniGrid-BlockedUnlockPickUp", entry_point="xminigrid.envs.minigrid.blockedunlockpickup:BlockedUnlockPickUp" + id="MiniGrid-BlockedUnlockPickUp", + entry_point="xminigrid.envs.minigrid.blockedunlockpickup:BlockedUnlockPickUp", ) # DoorKey diff --git a/src/xminigrid/benchmarks.py b/src/xminigrid/benchmarks.py index 40b0641..2b23790 100644 --- a/src/xminigrid/benchmarks.py +++ b/src/xminigrid/benchmarks.py @@ -18,13 +18,12 @@ DATA_PATH = os.environ.get("XLAND_MINIGRID_DATA", os.path.expanduser("~/.xland_minigrid")) NAME2HFFILENAME = { - "trivial-1m": "trivial_1m", - "small-1m": "small_1m", - "small-dist-1m": "small_dist_1m", - "medium-1m": "medium_1m_v1", - "medium-3m": "medium_3m_v1", - "high-1m": "high_1m", - "high-3m": "high_3m", + "trivial-1m": "trivial_1m_v2", + "small-1m": "small_1m_v2", + "medium-1m": "medium_1m_v2", + "medium-3m": "medium_3m_v2", + "high-1m": "high_1m_v2", + "high-3m": "high_3m_v2", } diff --git a/src/xminigrid/core/constants.py b/src/xminigrid/core/constants.py index 1d36d73..f9b9241 100644 --- a/src/xminigrid/core/constants.py +++ b/src/xminigrid/core/constants.py @@ -1,48 +1,44 @@ import jax.numpy as jnp from flax import struct +NUM_ACTIONS = 6 + # GRID: [tile, color] NUM_LAYERS = 2 -NUM_TILES = 15 -NUM_COLORS = 14 -NUM_ACTIONS = 6 +NUM_TILES = 13 +NUM_COLORS = 12 -# TODO: do we really need END_OF_MAP? seem like unseen can be used instead... # enums, kinda... class Tiles(struct.PyTreeNode): EMPTY: int = struct.field(pytree_node=False, default=0) - END_OF_MAP: int = struct.field(pytree_node=False, default=1) - UNSEEN: int = struct.field(pytree_node=False, default=2) - FLOOR: int = struct.field(pytree_node=False, default=3) - WALL: int = struct.field(pytree_node=False, default=4) - BALL: int = struct.field(pytree_node=False, default=5) - SQUARE: int = struct.field(pytree_node=False, default=6) - PYRAMID: int = struct.field(pytree_node=False, default=7) - GOAL: int = struct.field(pytree_node=False, default=8) - KEY: int = struct.field(pytree_node=False, default=9) - DOOR_LOCKED: int = struct.field(pytree_node=False, default=10) - DOOR_CLOSED: int = struct.field(pytree_node=False, default=11) - DOOR_OPEN: int = struct.field(pytree_node=False, default=12) - HEX: int = struct.field(pytree_node=False, default=13) - STAR: int = struct.field(pytree_node=False, default=14) + FLOOR: int = struct.field(pytree_node=False, default=1) + WALL: int = struct.field(pytree_node=False, default=2) + BALL: int = struct.field(pytree_node=False, default=3) + SQUARE: int = struct.field(pytree_node=False, default=4) + PYRAMID: int = struct.field(pytree_node=False, default=5) + GOAL: int = struct.field(pytree_node=False, default=6) + KEY: int = struct.field(pytree_node=False, default=7) + DOOR_LOCKED: int = struct.field(pytree_node=False, default=8) + DOOR_CLOSED: int = struct.field(pytree_node=False, default=9) + DOOR_OPEN: int = struct.field(pytree_node=False, default=10) + HEX: int = struct.field(pytree_node=False, default=11) + STAR: int = struct.field(pytree_node=False, default=12) class Colors(struct.PyTreeNode): EMPTY: int = struct.field(pytree_node=False, default=0) - END_OF_MAP: int = struct.field(pytree_node=False, default=1) - UNSEEN: int = struct.field(pytree_node=False, default=2) - RED: int = struct.field(pytree_node=False, default=3) - GREEN: int = struct.field(pytree_node=False, default=4) - BLUE: int = struct.field(pytree_node=False, default=5) - PURPLE: int = struct.field(pytree_node=False, default=6) - YELLOW: int = struct.field(pytree_node=False, default=7) - GREY: int = struct.field(pytree_node=False, default=8) - BLACK: int = struct.field(pytree_node=False, default=9) - ORANGE: int = struct.field(pytree_node=False, default=10) - WHITE: int = struct.field(pytree_node=False, default=11) - BROWN: int = struct.field(pytree_node=False, default=12) - PINK: int = struct.field(pytree_node=False, default=13) + RED: int = struct.field(pytree_node=False, default=1) + GREEN: int = struct.field(pytree_node=False, default=2) + BLUE: int = struct.field(pytree_node=False, default=3) + PURPLE: int = struct.field(pytree_node=False, default=4) + YELLOW: int = struct.field(pytree_node=False, default=5) + GREY: int = struct.field(pytree_node=False, default=6) + BLACK: int = struct.field(pytree_node=False, default=7) + ORANGE: int = struct.field(pytree_node=False, default=8) + WHITE: int = struct.field(pytree_node=False, default=9) + BROWN: int = struct.field(pytree_node=False, default=10) + PINK: int = struct.field(pytree_node=False, default=11) # Only ~100 combinations so far, better to preallocate them @@ -65,7 +61,6 @@ class Colors(struct.PyTreeNode): WALKABLE = jnp.array( ( - Tiles.EMPTY, Tiles.FLOOR, Tiles.GOAL, Tiles.DOOR_OPEN, @@ -83,12 +78,7 @@ class Colors(struct.PyTreeNode): ) ) -FREE_TO_PUT_DOWN = jnp.array( - ( - Tiles.EMPTY, - Tiles.FLOOR, - ) -) +FREE_TO_PUT_DOWN = jnp.array((Tiles.FLOOR,)) LOS_BLOCKING = jnp.array( ( diff --git a/src/xminigrid/core/grid.py b/src/xminigrid/core/grid.py index 803b6c7..b3152bd 100644 --- a/src/xminigrid/core/grid.py +++ b/src/xminigrid/core/grid.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Callable, Union +from typing import Callable import jax import jax.numpy as jnp @@ -22,7 +22,7 @@ def equal(tile1: Tile, tile2: Tile) -> Tile: def get_neighbouring_tiles(grid: GridState, y: IntOrArray, x: IntOrArray) -> tuple[Tile, Tile, Tile, Tile]: # end_of_map = TILES_REGISTRY[Tiles.END_OF_MAP, Colors.END_OF_MAP] - end_of_map = Tiles.END_OF_MAP + end_of_map = Tiles.EMPTY up_tile = grid.at[y - 1, x].get(mode="fill", fill_value=end_of_map) right_tile = grid.at[y, x + 1].get(mode="fill", fill_value=end_of_map) diff --git a/src/xminigrid/core/observation.py b/src/xminigrid/core/observation.py index e2c4ab8..e8ea031 100644 --- a/src/xminigrid/core/observation.py +++ b/src/xminigrid/core/observation.py @@ -12,7 +12,7 @@ def crop_field_of_view(grid: GridState, agent: AgentState, height: int, width: i grid = jnp.pad( grid, pad_width=((height, height), (width, width), (0, 0)), - constant_values=Tiles.END_OF_MAP, + constant_values=Tiles.EMPTY, ) # account for padding y = agent.position[0] + height @@ -110,8 +110,8 @@ def minigrid_field_of_view(grid: GridState, agent: AgentState, height: int, widt fov_grid = crop_field_of_view(grid, agent, height, width) fov_grid = align_with_up(fov_grid, agent.direction) mask = generate_viz_mask_minigrid(fov_grid) - # set UNSEEN value for all layers (including colors, as UNSEEN color has same id value) - fov_grid = jnp.where(mask[..., None], fov_grid, Tiles.UNSEEN) + # set EMPTY as unseen value for all layers (including colors, as EMPTY color has same id value) + fov_grid = jnp.where(mask[..., None], fov_grid, Tiles.EMPTY) # TODO: should we even do this? Agent with good memory can remember what he picked up. # WARN: this can overwrite tile the agent is on, GOAL for example. diff --git a/src/xminigrid/core/rules.py b/src/xminigrid/core/rules.py index 5f903b0..83a6b89 100644 --- a/src/xminigrid/core/rules.py +++ b/src/xminigrid/core/rules.py @@ -202,18 +202,18 @@ def __call__(self, grid, agent, action, position): tile = grid[position[0], position[1]] def _rule_fn(grid): - empty_tile = TILES_REGISTRY[Tiles.EMPTY, Colors.EMPTY] + floor_tile = TILES_REGISTRY[Tiles.FLOOR, Colors.BLACK] y, x = position up, _, down, _ = get_neighbouring_tiles(grid, y, x) grid = jax.lax.select( equal(tile, self.tile_b) & equal(down, self.tile_a), - grid.at[y + 1, x].set(self.prod_tile).at[y, x].set(empty_tile), + grid.at[y + 1, x].set(self.prod_tile).at[y, x].set(floor_tile), grid, ) grid = jax.lax.select( equal(tile, self.tile_a) & equal(up, self.tile_b), - grid.at[y - 1, x].set(self.prod_tile).at[y, x].set(empty_tile), + grid.at[y - 1, x].set(self.prod_tile).at[y, x].set(floor_tile), grid, ) return grid @@ -243,18 +243,19 @@ def __call__(self, grid, agent, action, position): tile = grid[position[0], position[1]] def _rule_fn(grid): - empty_tile = TILES_REGISTRY[Tiles.EMPTY, Colors.EMPTY] + floor_tile = TILES_REGISTRY[Tiles.FLOOR, Colors.BLACK] + y, x = position _, right, _, left = get_neighbouring_tiles(grid, y, x) grid = jax.lax.select( equal(tile, self.tile_b) & equal(left, self.tile_a), - grid.at[y, x - 1].set(self.prod_tile).at[y, x].set(empty_tile), + grid.at[y, x - 1].set(self.prod_tile).at[y, x].set(floor_tile), grid, ) grid = jax.lax.select( equal(tile, self.tile_a) & equal(right, self.tile_b), - grid.at[y, x + 1].set(self.prod_tile).at[y, x].set(empty_tile), + grid.at[y, x + 1].set(self.prod_tile).at[y, x].set(floor_tile), grid, ) return grid @@ -284,18 +285,18 @@ def __call__(self, grid, agent, action, position): tile = grid[position[0], position[1]] def _rule_fn(grid): - empty_tile = TILES_REGISTRY[Tiles.EMPTY, Colors.EMPTY] + floor_tile = TILES_REGISTRY[Tiles.FLOOR, Colors.BLACK] y, x = position up, _, down, _ = get_neighbouring_tiles(grid, y, x) grid = jax.lax.select( equal(tile, self.tile_b) & equal(up, self.tile_a), - grid.at[y - 1, x].set(self.prod_tile).at[y, x].set(empty_tile), + grid.at[y - 1, x].set(self.prod_tile).at[y, x].set(floor_tile), grid, ) grid = jax.lax.select( equal(tile, self.tile_a) & equal(down, self.tile_b), - grid.at[y + 1, x].set(self.prod_tile).at[y, x].set(empty_tile), + grid.at[y + 1, x].set(self.prod_tile).at[y, x].set(floor_tile), grid, ) return grid @@ -325,18 +326,18 @@ def __call__(self, grid, agent, action, position): tile = grid[position[0], position[1]] def _rule_fn(grid): - empty_tile = TILES_REGISTRY[Tiles.EMPTY, Colors.EMPTY] + floor_tile = TILES_REGISTRY[Tiles.FLOOR, Colors.BLACK] y, x = position _, right, _, left = get_neighbouring_tiles(grid, y, x) grid = jax.lax.select( equal(tile, self.tile_b) & equal(right, self.tile_a), - grid.at[y, x + 1].set(self.prod_tile).at[y, x].set(empty_tile), + grid.at[y, x + 1].set(self.prod_tile).at[y, x].set(floor_tile), grid, ) grid = jax.lax.select( equal(tile, self.tile_a) & equal(left, self.tile_b), - grid.at[y, x - 1].set(self.prod_tile).at[y, x].set(empty_tile), + grid.at[y, x - 1].set(self.prod_tile).at[y, x].set(floor_tile), grid, ) return grid diff --git a/src/xminigrid/envs/xland.py b/src/xminigrid/envs/xland.py index 4acac15..9092005 100644 --- a/src/xminigrid/envs/xland.py +++ b/src/xminigrid/envs/xland.py @@ -27,7 +27,6 @@ init_tiles=jnp.array(((TILES_REGISTRY[Tiles.EMPTY, Colors.EMPTY],))), ) -_empty_tile = TILES_REGISTRY[Tiles.EMPTY, Colors.EMPTY] _wall_tile = TILES_REGISTRY[Tiles.WALL, Colors.GREY] # colors for doors between rooms _allowed_colors = jnp.array( diff --git a/src/xminigrid/experimental/img_obs.py b/src/xminigrid/experimental/img_obs.py new file mode 100644 index 0000000..cc854b4 --- /dev/null +++ b/src/xminigrid/experimental/img_obs.py @@ -0,0 +1,95 @@ +# jit-compatible RGB observations. Currently experimental! +# if it proves useful and necessary in the future, I will consider rewriting env.render in such style also +from __future__ import annotations + +import os + +import jax +import jax.numpy as jnp +import numpy as np + +from ..benchmarks import load_bz2_pickle, save_bz2_pickle +from ..core.constants import NUM_COLORS, NUM_LAYERS, TILES_REGISTRY +from ..rendering.rgb_render import render_tile +from ..wrappers import Wrapper + +CACHE_PATH = os.environ.get("XLAND_MINIGRID_CACHE", os.path.expanduser("~/.xland_minigrid")) +FORCE_RELOAD = os.environ.get("XLAND_MINIGRID_RELOAD_CACHE", False) + + +def build_cache(tiles: np.ndarray, tile_size: int = 32) -> tuple[np.ndarray, np.ndarray]: + cache = np.zeros((tiles.shape[0], tiles.shape[1], tile_size, tile_size, 3), dtype=np.uint8) + agent_cache = np.zeros((tiles.shape[0], tiles.shape[1], tile_size, tile_size, 3), dtype=np.uint8) + + for y in range(tiles.shape[0]): + for x in range(tiles.shape[1]): + # rendering tile + tile_img = render_tile( + tile=tuple(tiles[y, x]), + agent_direction=None, + highlight=False, + tile_size=int(tile_size), + ) + cache[y, x] = tile_img + + # rendering agent on top + tile_w_agent_img = render_tile( + tile=tuple(tiles[y, x]), + agent_direction=0, + highlight=False, + tile_size=int(tile_size), + ) + agent_cache[y, x] = tile_w_agent_img + + return cache, agent_cache + + +# building cache of pre-rendered tiles +TILE_SIZE = 32 + +cache_path = os.path.join(CACHE_PATH, "render_cache") + +if not os.path.exists(cache_path) or FORCE_RELOAD: + os.makedirs(CACHE_PATH, exist_ok=True) + print("Building rendering cache, may take a while...") + TILE_CACHE, TILE_W_AGENT_CACHE = build_cache(np.asarray(TILES_REGISTRY), tile_size=TILE_SIZE) + TILE_CACHE = jnp.asarray(TILE_CACHE).reshape(-1, TILE_SIZE, TILE_SIZE, 3) + TILE_W_AGENT_CACHE = jnp.asarray(TILE_W_AGENT_CACHE).reshape(-1, TILE_SIZE, TILE_SIZE, 3) + + print(f"Done. Cache is saved to {cache_path} and will be reused on consequent runs.") + save_bz2_pickle({"tile_cache": TILE_CACHE, "tile_agent_cache": TILE_W_AGENT_CACHE}, cache_path) + +TILE_CACHE = load_bz2_pickle(cache_path)["tile_cache"] +TILE_W_AGENT_CACHE = load_bz2_pickle(cache_path)["tile_agent_cache"] + + +# rendering with cached tiles +def _render_obs(obs: jax.Array) -> jax.Array: + view_size = obs.shape[0] + + obs_flat_idxs = obs[:, :, 0] * NUM_COLORS + obs[:, :, 1] + # render all tiles + rendered_obs = jnp.take(TILE_CACHE, obs_flat_idxs, axis=0) + + # add agent tile + agent_tile = TILE_W_AGENT_CACHE[obs_flat_idxs[view_size - 1, view_size // 2]] + rendered_obs = rendered_obs.at[view_size - 1, view_size // 2].set(agent_tile) + # [view_size, view_size, tile_size, tile_size, 3] -> [view_size * tile_size, view_size * tile_size, 3] + rendered_obs = rendered_obs.transpose((0, 2, 1, 3, 4)).reshape(view_size * TILE_SIZE, view_size * TILE_SIZE, 3) + + return rendered_obs + + +class RGBImgObservationWrapper(Wrapper): + def observation_shape(self, params): + return params.view_size * TILE_SIZE, params.view_size * TILE_SIZE, 3 + + def reset(self, params, key): + timestep = self._env.reset(params, key) + timestep = timestep.replace(observation=_render_obs(timestep.observation)) + return timestep + + def step(self, params, timestep, action): + timestep = self._env.step(params, timestep, action) + timestep = timestep.replace(observation=_render_obs(timestep.observation)) + return timestep diff --git a/src/xminigrid/manual_control.py b/src/xminigrid/manual_control.py index 1dcd31f..43f9e7b 100644 --- a/src/xminigrid/manual_control.py +++ b/src/xminigrid/manual_control.py @@ -1,5 +1,9 @@ +from __future__ import annotations + import argparse +import os +import imageio.v3 as iio import jax import numpy as np import pygame @@ -7,16 +11,33 @@ from pygame.event import Event import xminigrid -from xminigrid.wrappers import GymAutoResetWrapper from .environment import Environment, EnvParamsT +from .rendering.text_render import print_ruleset from .types import EnvCarryT class ManualControl: - def __init__(self, env: Environment[EnvParamsT, EnvCarryT], env_params: EnvParamsT): + def __init__( + self, + env: Environment[EnvParamsT, EnvCarryT], + env_params: EnvParamsT, + agent_view: bool = False, + save_video: bool = False, + video_path: str | None = None, + video_format: str = ".mp4", + video_fps: int = 8, + ): self.env = env self.env_params = env_params + self.agent_view = agent_view + self.save_video = save_video + self.video_path = video_path + self.video_format = video_format + self.video_fps = video_fps + + if self.save_video: + self.frames = [] self._reset = jax.jit(self.env.reset) self._step = jax.jit(self.env.step) @@ -33,7 +54,14 @@ def __init__(self, env: Environment[EnvParamsT, EnvCarryT], env_params: EnvParam def render(self) -> None: assert self.timestep is not None - img = self.env.render(self.env_params, self.timestep) + if self.agent_view: + img = self.timestep.observation + else: + img = self.env.render(self.env_params, self.timestep) + + if self.save_video: + self.frames.append(img) + # [h, w, c] -> [w, h, c] img = np.transpose(img, axes=(1, 0, 2)) @@ -97,6 +125,9 @@ def step(self, action: int) -> None: ) self.render() + if self.timestep.last(): + self.reset() + def reset(self) -> None: print("Reset!") self._key, reset_key = jax.random.split(self._key) @@ -138,20 +169,54 @@ def close(self) -> None: if self.window: pygame.quit() + if self.save_video: + assert self.video_path is not None + save_path = os.path.join(self.video_path, f"manual_control_rollout{self.video_format}") + if self.video_format == ".mp4": + iio.imwrite(save_path, self.frames, format_hint=".mp4", fps=self.video_fps) + elif self.video_format == ".gif": + iio.imwrite( + save_path, self.frames[:-1], format_hint=".gif", duration=(1000 * 1 / self.video_fps), loop=10 + ) + # iio.imwrite(save_path, self.frames, format_hint=".gif", duration=(1000 * 1 / self.video_fps), loop=10) + else: + raise RuntimeError("Unknown video format! Should be one of ('.mp4', '.gif')") + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--env-id", type=str, default="MiniGrid-Empty-5x5", choices=xminigrid.registered_environments()) parser.add_argument("--benchmark-id", type=str, default="trivial-1m", choices=xminigrid.registered_benchmarks()) parser.add_argument("--ruleset-id", type=int, default=0) + parser.add_argument("--agent-view", action="store_true") + parser.add_argument("--save-video", action="store_true") + parser.add_argument("--video-path", type=str, default=".") + parser.add_argument("--video-format", type=str, default=".mp4", choices=(".mp4", ".gif")) + parser.add_argument("--video-fps", type=int, default=5) args = parser.parse_args() env, env_params = xminigrid.make(args.env_id) - env = GymAutoResetWrapper(env) + + if args.agent_view: + from xminigrid.experimental.img_obs import RGBImgObservationWrapper + + env = RGBImgObservationWrapper(env) if "XLand" in args.env_id: bench = xminigrid.load_benchmark(args.benchmark_id) - env_params = env_params.replace(ruleset=bench.get_ruleset(args.ruleset_id)) - - control = ManualControl(env=env, env_params=env_params) + ruleset = bench.get_ruleset(args.ruleset_id) + + env_params = env_params.replace(ruleset=ruleset) + print_ruleset(ruleset) + print() + + control = ManualControl( + env=env, + env_params=env_params, + agent_view=args.agent_view, + save_video=args.save_video, + video_path=args.video_path, + video_format=args.video_format, + video_fps=args.video_fps, + ) control.start() diff --git a/src/xminigrid/rendering/rgb_render.py b/src/xminigrid/rendering/rgb_render.py index 5c8ea81..0343a93 100644 --- a/src/xminigrid/rendering/rgb_render.py +++ b/src/xminigrid/rendering/rgb_render.py @@ -5,7 +5,7 @@ import numpy as np from ..core.constants import Colors, Tiles -from ..types import AgentState, GridState, IntOrArray +from ..types import AgentState, IntOrArray from .utils import ( downsample, fill_coords, @@ -18,6 +18,7 @@ ) COLORS_MAP = { + Colors.EMPTY: np.array((255, 255, 255)), # just a placeholder Colors.RED: np.array((255, 0, 0)), Colors.GREEN: np.array((0, 255, 0)), Colors.BLUE: np.array((0, 0, 255)), @@ -32,6 +33,16 @@ } +def _render_empty(img: np.ndarray, color: int): + fill_coords(img, point_in_rect(0.45, 0.55, 0.2, 0.65), COLORS_MAP[Colors.RED]) + fill_coords(img, point_in_rect(0.45, 0.55, 0.7, 0.85), COLORS_MAP[Colors.RED]) + + fill_coords(img, point_in_rect(0, 0.031, 0, 1), COLORS_MAP[Colors.RED]) + fill_coords(img, point_in_rect(0, 1, 0, 0.031), COLORS_MAP[Colors.RED]) + fill_coords(img, point_in_rect(1 - 0.031, 1, 0, 1), COLORS_MAP[Colors.RED]) + fill_coords(img, point_in_rect(0, 1, 1 - 0.031, 1), COLORS_MAP[Colors.RED]) + + def _render_floor(img: np.ndarray, color: int): # draw the grid lines (top and left edges) fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100)) @@ -165,7 +176,7 @@ def _render_player(img: np.ndarray, direction: int): Tiles.DOOR_LOCKED: _render_door_locked, Tiles.DOOR_CLOSED: _render_door_closed, Tiles.DOOR_OPEN: _render_door_open, - Tiles.EMPTY: lambda img, color: img, + Tiles.EMPTY: _render_empty, } @@ -196,7 +207,7 @@ def get_highlight_mask(grid: np.ndarray, agent: AgentState | None, view_size: in @functools.cache def render_tile( - tile: np.ndarray, agent_direction: int | None = None, highlight: bool = False, tile_size: int = 32, subdivs: int = 3 + tile: tuple, agent_direction: int | None = None, highlight: bool = False, tile_size: int = 32, subdivs: int = 3 ) -> np.ndarray: img = np.full((tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8, fill_value=255) # draw tile diff --git a/src/xminigrid/rendering/text_render.py b/src/xminigrid/rendering/text_render.py index 8f3887c..8b7c514 100644 --- a/src/xminigrid/rendering/text_render.py +++ b/src/xminigrid/rendering/text_render.py @@ -5,8 +5,6 @@ from ..types import AgentState, RuleSet COLOR_NAMES = { - Colors.END_OF_MAP: "red", - Colors.UNSEEN: "white", Colors.EMPTY: "white", Colors.RED: "red", Colors.GREEN: "green", @@ -22,8 +20,6 @@ } TILE_STR = { - Tiles.END_OF_MAP: "!", - Tiles.UNSEEN: "?", Tiles.EMPTY: " ", Tiles.FLOOR: ".", Tiles.WALL: "☰", diff --git a/src/xminigrid/wrappers.py b/src/xminigrid/wrappers.py index 15e9810..dcfa9f1 100644 --- a/src/xminigrid/wrappers.py +++ b/src/xminigrid/wrappers.py @@ -16,6 +16,12 @@ def __init__(self, env: Environment[EnvParamsT, EnvCarryT]): def default_params(self, **kwargs) -> EnvParamsT: return self._env.default_params(**kwargs) + def num_actions(self, params: EnvParamsT) -> int: + return self._env.num_actions(params) + + def observation_shape(self, params: EnvParamsT) -> tuple[int, int, int]: + return self._env.observation_shape(params) + def time_limit(self, params: EnvParamsT) -> int: return self._env.time_limit(params) diff --git a/training/nn.py b/training/nn.py index a067df6..759c579 100644 --- a/training/nn.py +++ b/training/nn.py @@ -100,23 +100,37 @@ class ActorCriticRNN(nn.Module): rnn_hidden_dim: int = 64 rnn_num_layers: int = 1 head_hidden_dim: int = 64 + img_obs: bool = False @nn.compact def __call__(self, inputs: ActorCriticInput, hidden: jax.Array) -> tuple[distrax.Categorical, jax.Array, jax.Array]: B, S = inputs["observation"].shape[:2] # encoder from https://github.com/lcswillems/rl-starter-files/blob/master/model.py - img_encoder = nn.Sequential( - [ - nn.Conv(16, (2, 2), padding="VALID", kernel_init=orthogonal(math.sqrt(2))), - nn.relu, - # use this only for image sizes >= 7 - # MaxPool2d((2, 2)), - nn.Conv(32, (2, 2), padding="VALID", kernel_init=orthogonal(math.sqrt(2))), - nn.relu, - nn.Conv(64, (2, 2), padding="VALID", kernel_init=orthogonal(math.sqrt(2))), - nn.relu, - ] - ) + if self.img_obs: + img_encoder = nn.Sequential( + [ + nn.Conv(16, (3, 3), strides=2, padding="VALID", kernel_init=orthogonal(math.sqrt(2))), + nn.relu, + nn.Conv(32, (3, 3), strides=2, padding="VALID", kernel_init=orthogonal(math.sqrt(2))), + nn.relu, + nn.Conv(32, (3, 3), strides=2, padding="VALID", kernel_init=orthogonal(math.sqrt(2))), + nn.relu, + nn.Conv(32, (3, 3), strides=2, padding="VALID", kernel_init=orthogonal(math.sqrt(2))), + ] + ) + else: + img_encoder = nn.Sequential( + [ + nn.Conv(16, (2, 2), padding="VALID", kernel_init=orthogonal(math.sqrt(2))), + nn.relu, + # use this only for image sizes >= 7 + # MaxPool2d((2, 2)), + nn.Conv(32, (2, 2), padding="VALID", kernel_init=orthogonal(math.sqrt(2))), + nn.relu, + nn.Conv(64, (2, 2), padding="VALID", kernel_init=orthogonal(math.sqrt(2))), + nn.relu, + ] + ) action_encoder = nn.Embed(self.num_actions, self.action_emb_dim) rnn_core = BatchedRNNModel(self.rnn_hidden_dim, self.rnn_num_layers) diff --git a/training/train_meta_task.py b/training/train_meta_task.py index 2c0f2db..e491fcc 100644 --- a/training/train_meta_task.py +++ b/training/train_meta_task.py @@ -36,6 +36,7 @@ class TrainConfig: name: str = "meta-task-ppo" env_id: str = "XLand-MiniGrid-R1-9x9" benchmark_id: str = "trivial-1m" + img_obs: bool = False # agent action_emb_dim: int = 16 rnn_hidden_dim: int = 1024 @@ -90,6 +91,12 @@ def linear_schedule(count): env, env_params = xminigrid.make(config.env_id) env = GymAutoResetWrapper(env) + # enabling image observations if needed + if config.img_obs: + from xminigrid.experimental.img_obs import RGBImgObservationWrapper + + env = RGBImgObservationWrapper(env) + # loading benchmark benchmark = xminigrid.load_benchmark(config.benchmark_id) @@ -103,6 +110,7 @@ def linear_schedule(count): rnn_hidden_dim=config.rnn_hidden_dim, rnn_num_layers=config.rnn_num_layers, head_hidden_dim=config.head_hidden_dim, + img_obs=config.img_obs, ) # [batch_size, seq_len, ...] init_obs = { diff --git a/training/train_single_task.py b/training/train_single_task.py index 0c63f49..e4064a9 100644 --- a/training/train_single_task.py +++ b/training/train_single_task.py @@ -29,6 +29,7 @@ class TrainConfig: group: str = "default" name: str = "single-task-ppo" env_id: str = "MiniGrid-Empty-6x6" + img_obs: bool = False # agent action_emb_dim: int = 16 rnn_hidden_dim: int = 1024 @@ -74,6 +75,12 @@ def linear_schedule(count): env, env_params = xminigrid.make(config.env_id) env = GymAutoResetWrapper(env) + # enabling image observations if needed + if config.img_obs: + from xminigrid.experimental.img_obs import RGBImgObservationWrapper + + env = RGBImgObservationWrapper(env) + # setup training state rng = jax.random.PRNGKey(config.seed) rng, _rng = jax.random.split(rng) @@ -84,6 +91,7 @@ def linear_schedule(count): rnn_hidden_dim=config.rnn_hidden_dim, rnn_num_layers=config.rnn_num_layers, head_hidden_dim=config.head_hidden_dim, + img_obs=config.img_obs, ) # [batch_size, seq_len, ...] init_obs = {