Skip to content

Commit

Permalink
Add image-observation benchmarking (--img-obs flag); fix render-cache bug (cache directory was not created)
Browse files Browse the repository at this point in the history
  • Loading branch information
Howuhh committed Mar 2, 2024
1 parent 50d61b5 commit 2b0f87b
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 8 deletions.
24 changes: 20 additions & 4 deletions scripts/benchmark_xland.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,29 @@
parser = argparse.ArgumentParser()
parser.add_argument("--env-id", type=str, default="MiniGrid-Empty-16x16")
parser.add_argument("--benchmark-id", type=str, default="Trivial")
parser.add_argument("--img-obs", action="store_true")
parser.add_argument("--timesteps", type=int, default=1000)
parser.add_argument("--num-envs", type=int, default=8192)
parser.add_argument("--num-repeat", type=int, default=10, help="Number of timing repeats")
parser.add_argument("--num-iter", type=int, default=1, help="Number of runs during one repeat (time is summed)")


def build_benchmark(env_id: str, num_envs: int, timesteps: int, benchmark_id: Optional[str] = None):
def build_benchmark(
env_id: str,
num_envs: int,
timesteps: int,
benchmark_id: Optional[str] = None,
img_obs: bool = False,
):
env, env_params = xminigrid.make(env_id)
env = GymAutoResetWrapper(env)

# enable img observations if needed
if img_obs:
from xminigrid.experimental.img_obs import RGBImgObservationWrapper

env = RGBImgObservationWrapper(env)

# choose XLand benchmark if needed
if "XLand-MiniGrid" in env_id and benchmark_id is not None:
ruleset = load_benchmark(benchmark_id).sample_ruleset(jax.random.PRNGKey(0))
Expand Down Expand Up @@ -73,13 +87,15 @@ def timeit_benchmark(args, benchmark_fn):
print("Num devices for pmap:", num_devices)

# building for single env benchmarking
benchmark_fn_single = build_benchmark(args.env_id, 1, args.timesteps, args.benchmark_id)
benchmark_fn_single = build_benchmark(args.env_id, 1, args.timesteps, args.benchmark_id, args.img_obs)
benchmark_fn_single = jax.jit(benchmark_fn_single)
# building vmap for vectorization benchmarking
benchmark_fn_vmap = build_benchmark(args.env_id, args.num_envs, args.timesteps, args.benchmark_id)
benchmark_fn_vmap = build_benchmark(args.env_id, args.num_envs, args.timesteps, args.benchmark_id, args.img_obs)
benchmark_fn_vmap = jax.jit(benchmark_fn_vmap)
# building pmap for multi-GPU benchmarking (each device vmaps over num_envs // num_devices environments)
benchmark_fn_pmap = build_benchmark(args.env_id, args.num_envs // num_devices, args.timesteps, args.benchmark_id)
benchmark_fn_pmap = build_benchmark(
args.env_id, args.num_envs // num_devices, args.timesteps, args.benchmark_id, args.img_obs
)
benchmark_fn_pmap = jax.pmap(benchmark_fn_pmap)

key = jax.random.PRNGKey(0)
Expand Down
16 changes: 14 additions & 2 deletions scripts/benchmark_xland_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,24 @@

parser = argparse.ArgumentParser()
parser.add_argument("--benchmark-id", type=str, default="trivial-1m")
parser.add_argument("--img-obs", action="store_true")
parser.add_argument("--timesteps", type=int, default=1000)
parser.add_argument("--num-repeat", type=int, default=10, help="Number of timing repeats")
parser.add_argument("--num-iter", type=int, default=1, help="Number of runs during one repeat (time is summed)")


def build_benchmark(env_id: str, num_envs: int, timesteps: int, benchmark_id: Optional[str] = None):
def build_benchmark(
env_id: str, num_envs: int, timesteps: int, benchmark_id: Optional[str] = None, img_obs: bool = False
):
env, env_params = xminigrid.make(env_id)
env = GymAutoResetWrapper(env)

# enable img observations if needed
if img_obs:
from xminigrid.experimental.img_obs import RGBImgObservationWrapper

env = RGBImgObservationWrapper(env)

# choose XLand benchmark if needed
if "XLand-MiniGrid" in env_id and benchmark_id is not None:
ruleset = load_benchmark(benchmark_id).sample_ruleset(jax.random.PRNGKey(0))
Expand Down Expand Up @@ -77,7 +87,9 @@ def timeit_benchmark(args, benchmark_fn):
for env_id in tqdm(environments, desc="Envs.."):
assert num_envs % num_devices == 0
# building pmap for multi-GPU benchmarking (each device vmaps over num_envs // num_devices environments)
benchmark_fn_pmap = build_benchmark(env_id, num_envs // num_devices, args.timesteps, args.benchmark_id)
benchmark_fn_pmap = build_benchmark(
env_id, num_envs // num_devices, args.timesteps, args.benchmark_id, args.img_obs
)
benchmark_fn_pmap = jax.pmap(benchmark_fn_pmap)

# benchmarking
Expand Down
1 change: 0 additions & 1 deletion scripts/generate_benchmarks.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ python scripts/ruleset_generator.py \
--total_rulesets=1_000_000 \
--save_path="trivial_1m"


# small
python scripts/ruleset_generator.py \
--prune_chain \
Expand Down
3 changes: 2 additions & 1 deletion src/xminigrid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,8 @@

# BlockedUnlockPickUp
register(
id="MiniGrid-BlockedUnlockPickUp", entry_point="xminigrid.envs.minigrid.blockedunlockpickup:BlockedUnlockPickUp"
id="MiniGrid-BlockedUnlockPickUp",
entry_point="xminigrid.envs.minigrid.blockedunlockpickup:BlockedUnlockPickUp",
)

# DoorKey
Expand Down
2 changes: 2 additions & 0 deletions src/xminigrid/experimental/img_obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def build_cache(tiles: np.ndarray, tile_size: int = 32) -> tuple[np.ndarray, np.

cache_path = os.path.join(CACHE_PATH, "render_cache")
if not os.path.exists(cache_path):
os.makedirs(CACHE_PATH, exist_ok=True)

TILE_CACHE, TILE_W_AGENT_CACHE = build_cache(np.asarray(TILES_REGISTRY), tile_size=TILE_SIZE)
TILE_CACHE = jnp.asarray(TILE_CACHE).reshape(-1, TILE_SIZE, TILE_SIZE, 3)
TILE_W_AGENT_CACHE = jnp.asarray(TILE_W_AGENT_CACHE).reshape(-1, TILE_SIZE, TILE_SIZE, 3)
Expand Down

0 comments on commit 2b0f87b

Please sign in to comment.