Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RGB image observation wrapper compatible with jit #9

Merged
merged 14 commits into from
Mar 24, 2024
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ repos:
rev: v1.1.350
hooks:
- id: pyright
# args: [--project=pyproject.toml]
args: [--project=pyproject.toml]
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ On the high level, current API combines [dm_env](https://github.com/google-deepm
import jax
import xminigrid
from xminigrid.wrappers import GymAutoResetWrapper
from xminigrid.experimental.img_obs import RGBImgObservationWrapper

key = jax.random.PRNGKey(0)
reset_key, ruleset_key = jax.random.split(key)
Expand All @@ -109,6 +110,9 @@ env_params = env_params.replace(ruleset=ruleset)
# auto-reset wrapper
env = GymAutoResetWrapper(env)

# render obs as rgb images if needed (warn: this will affect speed greatly)
env = RGBImgObservationWrapper(env)

# fully jit-compatible step and reset methods
timestep = jax.jit(env.reset)(env_params, reset_key)
timestep = jax.jit(env.step)(env_params, timestep, action=0)
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ dependencies = [
"flax>=0.8.0",
"rich>=13.4.2",
"chex>=0.1.85",
"imageio>=2.31.2",
"imageio-ffmpeg>=0.4.9",
]

[project.optional-dependencies]
Expand All @@ -49,8 +51,6 @@ dev = [

baselines = [
"matplotlib>=3.7.2",
"imageio>=2.31.2",
"imageio-ffmpeg>=0.4.9",
"wandb>=0.15.10",
"pyrallis>=0.3.1",
"distrax>=0.1.4",
Expand Down
24 changes: 20 additions & 4 deletions scripts/benchmark_xland.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,29 @@
parser = argparse.ArgumentParser()
parser.add_argument("--env-id", type=str, default="MiniGrid-Empty-16x16")
parser.add_argument("--benchmark-id", type=str, default="Trivial")
parser.add_argument("--img-obs", action="store_true")
parser.add_argument("--timesteps", type=int, default=1000)
parser.add_argument("--num-envs", type=int, default=8192)
parser.add_argument("--num-repeat", type=int, default=10, help="Number of timing repeats")
parser.add_argument("--num-iter", type=int, default=1, help="Number of runs during one repeat (time is summed)")


def build_benchmark(env_id: str, num_envs: int, timesteps: int, benchmark_id: Optional[str] = None):
def build_benchmark(
env_id: str,
num_envs: int,
timesteps: int,
benchmark_id: Optional[str] = None,
img_obs: bool = False,
):
env, env_params = xminigrid.make(env_id)
env = GymAutoResetWrapper(env)

# enable img observations if needed
if img_obs:
from xminigrid.experimental.img_obs import RGBImgObservationWrapper

env = RGBImgObservationWrapper(env)

# choose XLand benchmark if needed
if "XLand-MiniGrid" in env_id and benchmark_id is not None:
ruleset = load_benchmark(benchmark_id).sample_ruleset(jax.random.PRNGKey(0))
Expand Down Expand Up @@ -73,13 +87,15 @@ def timeit_benchmark(args, benchmark_fn):
print("Num devices for pmap:", num_devices)

# building for single env benchmarking
benchmark_fn_single = build_benchmark(args.env_id, 1, args.timesteps, args.benchmark_id)
benchmark_fn_single = build_benchmark(args.env_id, 1, args.timesteps, args.benchmark_id, args.img_obs)
benchmark_fn_single = jax.jit(benchmark_fn_single)
# building vmap for vectorization benchmarking
benchmark_fn_vmap = build_benchmark(args.env_id, args.num_envs, args.timesteps, args.benchmark_id)
benchmark_fn_vmap = build_benchmark(args.env_id, args.num_envs, args.timesteps, args.benchmark_id, args.img_obs)
benchmark_fn_vmap = jax.jit(benchmark_fn_vmap)
# building pmap for multi-gpu benchmarking (each doing (num_envs / num_devices) vmaps)
benchmark_fn_pmap = build_benchmark(args.env_id, args.num_envs // num_devices, args.timesteps, args.benchmark_id)
benchmark_fn_pmap = build_benchmark(
args.env_id, args.num_envs // num_devices, args.timesteps, args.benchmark_id, args.img_obs
)
benchmark_fn_pmap = jax.pmap(benchmark_fn_pmap)

key = jax.random.PRNGKey(0)
Expand Down
16 changes: 14 additions & 2 deletions scripts/benchmark_xland_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,24 @@

parser = argparse.ArgumentParser()
parser.add_argument("--benchmark-id", type=str, default="trivial-1m")
parser.add_argument("--img-obs", action="store_true")
parser.add_argument("--timesteps", type=int, default=1000)
parser.add_argument("--num-repeat", type=int, default=10, help="Number of timing repeats")
parser.add_argument("--num-iter", type=int, default=1, help="Number of runs during one repeat (time is summed)")


def build_benchmark(env_id: str, num_envs: int, timesteps: int, benchmark_id: Optional[str] = None):
def build_benchmark(
env_id: str, num_envs: int, timesteps: int, benchmark_id: Optional[str] = None, img_obs: bool = False
):
env, env_params = xminigrid.make(env_id)
env = GymAutoResetWrapper(env)

# enable img observations if needed
if img_obs:
from xminigrid.experimental.img_obs import RGBImgObservationWrapper

env = RGBImgObservationWrapper(env)

# choose XLand benchmark if needed
if "XLand-MiniGrid" in env_id and benchmark_id is not None:
ruleset = load_benchmark(benchmark_id).sample_ruleset(jax.random.PRNGKey(0))
Expand Down Expand Up @@ -77,7 +87,9 @@ def timeit_benchmark(args, benchmark_fn):
for env_id in tqdm(environments, desc="Envs.."):
assert num_envs % num_devices == 0
# building pmap for multi-gpu benchmarking (each doing (num_envs / num_devices) vmaps)
benchmark_fn_pmap = build_benchmark(env_id, num_envs // num_devices, args.timesteps, args.benchmark_id)
benchmark_fn_pmap = build_benchmark(
env_id, num_envs // num_devices, args.timesteps, args.benchmark_id, args.img_obs
)
benchmark_fn_pmap = jax.pmap(benchmark_fn_pmap)

# benchmarking
Expand Down
22 changes: 10 additions & 12 deletions scripts/generate_benchmarks.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ python scripts/ruleset_generator.py \
--total_rulesets=1_000_000 \
--save_path="trivial_1m"


# small
python scripts/ruleset_generator.py \
--prune_chain \
Expand Down Expand Up @@ -41,17 +40,16 @@ python scripts/ruleset_generator.py \
--total_rulesets=1_000_000 \
--save_path="high_1m"


# medium + distractors
python scripts/ruleset_generator.py \
--prune_chain \
--prune_prob=0.8 \
--chain_depth=2 \
--sample_distractor_rules \
--num_distractor_rules=4 \
--num_distractor_objects=2 \
--total_rulesets=1_000_000 \
--save_path="medium_dist_1m"
## medium + distractors
#python scripts/ruleset_generator.py \
# --prune_chain \
# --prune_prob=0.8 \
# --chain_depth=2 \
# --sample_distractor_rules \
# --num_distractor_rules=4 \
# --num_distractor_objects=2 \
# --total_rulesets=1_000_000 \
# --save_path="medium_dist_1m"

# medium 3M
python scripts/ruleset_generator.py \
Expand Down
5 changes: 3 additions & 2 deletions src/xminigrid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .registration import make, register, registered_environments

# TODO: add __all__
__version__ = "0.6.0"
__version__ = "0.7.0"

# ---------- XLand-MiniGrid environments ----------

Expand Down Expand Up @@ -210,7 +210,8 @@

# BlockedUnlockPickUp
register(
id="MiniGrid-BlockedUnlockPickUp", entry_point="xminigrid.envs.minigrid.blockedunlockpickup:BlockedUnlockPickUp"
id="MiniGrid-BlockedUnlockPickUp",
entry_point="xminigrid.envs.minigrid.blockedunlockpickup:BlockedUnlockPickUp",
)

# DoorKey
Expand Down
13 changes: 6 additions & 7 deletions src/xminigrid/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,12 @@
DATA_PATH = os.environ.get("XLAND_MINIGRID_DATA", os.path.expanduser("~/.xland_minigrid"))

NAME2HFFILENAME = {
"trivial-1m": "trivial_1m",
"small-1m": "small_1m",
"small-dist-1m": "small_dist_1m",
"medium-1m": "medium_1m_v1",
"medium-3m": "medium_3m_v1",
"high-1m": "high_1m",
"high-3m": "high_3m",
"trivial-1m": "trivial_1m_v2",
"small-1m": "small_1m_v2",
"medium-1m": "medium_1m_v2",
"medium-3m": "medium_3m_v2",
"high-1m": "high_1m_v2",
"high-3m": "high_3m_v2",
}


Expand Down
66 changes: 28 additions & 38 deletions src/xminigrid/core/constants.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,44 @@
import jax.numpy as jnp
from flax import struct

NUM_ACTIONS = 6

# GRID: [tile, color]
NUM_LAYERS = 2
NUM_TILES = 15
NUM_COLORS = 14
NUM_ACTIONS = 6
NUM_TILES = 13
NUM_COLORS = 12


# TODO: do we really need END_OF_MAP? seem like unseen can be used instead...
# enums, kinda...
class Tiles(struct.PyTreeNode):
EMPTY: int = struct.field(pytree_node=False, default=0)
END_OF_MAP: int = struct.field(pytree_node=False, default=1)
UNSEEN: int = struct.field(pytree_node=False, default=2)
FLOOR: int = struct.field(pytree_node=False, default=3)
WALL: int = struct.field(pytree_node=False, default=4)
BALL: int = struct.field(pytree_node=False, default=5)
SQUARE: int = struct.field(pytree_node=False, default=6)
PYRAMID: int = struct.field(pytree_node=False, default=7)
GOAL: int = struct.field(pytree_node=False, default=8)
KEY: int = struct.field(pytree_node=False, default=9)
DOOR_LOCKED: int = struct.field(pytree_node=False, default=10)
DOOR_CLOSED: int = struct.field(pytree_node=False, default=11)
DOOR_OPEN: int = struct.field(pytree_node=False, default=12)
HEX: int = struct.field(pytree_node=False, default=13)
STAR: int = struct.field(pytree_node=False, default=14)
FLOOR: int = struct.field(pytree_node=False, default=1)
WALL: int = struct.field(pytree_node=False, default=2)
BALL: int = struct.field(pytree_node=False, default=3)
SQUARE: int = struct.field(pytree_node=False, default=4)
PYRAMID: int = struct.field(pytree_node=False, default=5)
GOAL: int = struct.field(pytree_node=False, default=6)
KEY: int = struct.field(pytree_node=False, default=7)
DOOR_LOCKED: int = struct.field(pytree_node=False, default=8)
DOOR_CLOSED: int = struct.field(pytree_node=False, default=9)
DOOR_OPEN: int = struct.field(pytree_node=False, default=10)
HEX: int = struct.field(pytree_node=False, default=11)
STAR: int = struct.field(pytree_node=False, default=12)


class Colors(struct.PyTreeNode):
EMPTY: int = struct.field(pytree_node=False, default=0)
END_OF_MAP: int = struct.field(pytree_node=False, default=1)
UNSEEN: int = struct.field(pytree_node=False, default=2)
RED: int = struct.field(pytree_node=False, default=3)
GREEN: int = struct.field(pytree_node=False, default=4)
BLUE: int = struct.field(pytree_node=False, default=5)
PURPLE: int = struct.field(pytree_node=False, default=6)
YELLOW: int = struct.field(pytree_node=False, default=7)
GREY: int = struct.field(pytree_node=False, default=8)
BLACK: int = struct.field(pytree_node=False, default=9)
ORANGE: int = struct.field(pytree_node=False, default=10)
WHITE: int = struct.field(pytree_node=False, default=11)
BROWN: int = struct.field(pytree_node=False, default=12)
PINK: int = struct.field(pytree_node=False, default=13)
RED: int = struct.field(pytree_node=False, default=1)
GREEN: int = struct.field(pytree_node=False, default=2)
BLUE: int = struct.field(pytree_node=False, default=3)
PURPLE: int = struct.field(pytree_node=False, default=4)
YELLOW: int = struct.field(pytree_node=False, default=5)
GREY: int = struct.field(pytree_node=False, default=6)
BLACK: int = struct.field(pytree_node=False, default=7)
ORANGE: int = struct.field(pytree_node=False, default=8)
WHITE: int = struct.field(pytree_node=False, default=9)
BROWN: int = struct.field(pytree_node=False, default=10)
PINK: int = struct.field(pytree_node=False, default=11)


# Only ~100 combinations so far, better to preallocate them
Expand All @@ -65,7 +61,6 @@ class Colors(struct.PyTreeNode):

WALKABLE = jnp.array(
(
Tiles.EMPTY,
Tiles.FLOOR,
Tiles.GOAL,
Tiles.DOOR_OPEN,
Expand All @@ -83,12 +78,7 @@ class Colors(struct.PyTreeNode):
)
)

FREE_TO_PUT_DOWN = jnp.array(
(
Tiles.EMPTY,
Tiles.FLOOR,
)
)
FREE_TO_PUT_DOWN = jnp.array((Tiles.FLOOR,))

LOS_BLOCKING = jnp.array(
(
Expand Down
4 changes: 2 additions & 2 deletions src/xminigrid/core/grid.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Callable, Union
from typing import Callable

import jax
import jax.numpy as jnp
Expand All @@ -22,7 +22,7 @@ def equal(tile1: Tile, tile2: Tile) -> Tile:

def get_neighbouring_tiles(grid: GridState, y: IntOrArray, x: IntOrArray) -> tuple[Tile, Tile, Tile, Tile]:
# end_of_map = TILES_REGISTRY[Tiles.END_OF_MAP, Colors.END_OF_MAP]
end_of_map = Tiles.END_OF_MAP
end_of_map = Tiles.EMPTY

up_tile = grid.at[y - 1, x].get(mode="fill", fill_value=end_of_map)
right_tile = grid.at[y, x + 1].get(mode="fill", fill_value=end_of_map)
Expand Down
6 changes: 3 additions & 3 deletions src/xminigrid/core/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def crop_field_of_view(grid: GridState, agent: AgentState, height: int, width: i
grid = jnp.pad(
grid,
pad_width=((height, height), (width, width), (0, 0)),
constant_values=Tiles.END_OF_MAP,
constant_values=Tiles.EMPTY,
)
# account for padding
y = agent.position[0] + height
Expand Down Expand Up @@ -110,8 +110,8 @@ def minigrid_field_of_view(grid: GridState, agent: AgentState, height: int, widt
fov_grid = crop_field_of_view(grid, agent, height, width)
fov_grid = align_with_up(fov_grid, agent.direction)
mask = generate_viz_mask_minigrid(fov_grid)
# set UNSEEN value for all layers (including colors, as UNSEEN color has same id value)
fov_grid = jnp.where(mask[..., None], fov_grid, Tiles.UNSEEN)
# set EMPTY as unseen value for all layers (including colors, as EMPTY color has same id value)
fov_grid = jnp.where(mask[..., None], fov_grid, Tiles.EMPTY)

# TODO: should we even do this? Agent with good memory can remember what he picked up.
# WARN: this can overwrite tile the agent is on, GOAL for example.
Expand Down
Loading
Loading