From c27d2f53cb0d1c99c13e207ed1fc4e3d6729715b Mon Sep 17 00:00:00 2001 From: Howuhh Date: Wed, 13 Dec 2023 21:57:46 +0300 Subject: [PATCH] added new rules for the extended benchmark --- README.md | 2 +- scripts/benchmark_xland.py | 9 +- scripts/benchmark_xland_all.py | 91 ++++++ src/xminigrid/core/rules.py | 284 +++++++++++++++++- .../envs/minigrid/blockedunlockpickup.py | 2 + src/xminigrid/envs/minigrid/playground.py | 2 +- src/xminigrid/envs/minigrid/unlockpickup.py | 2 + 7 files changed, 385 insertions(+), 7 deletions(-) create mode 100644 scripts/benchmark_xland_all.py diff --git a/README.md b/README.md index 1faf05a..f12d6ce 100644 --- a/README.md +++ b/README.md @@ -68,7 +68,7 @@ install the source as follows: git clone git@github.com:corl-team/xland-minigrid.git cd xland-minigrid # additional dependencies for baselines -pip install -e ".[benchmark]" +pip install -e ".[dev,benchmark]" ``` Note that the installation of JAX may differ depending on your hardware accelerator! We advise users to explicitly install the correct JAX version (see the [official installation guide](https://github.com/google/jax#installation)). 
diff --git a/scripts/benchmark_xland.py b/scripts/benchmark_xland.py index bb49ded..a7bd68b 100644 --- a/scripts/benchmark_xland.py +++ b/scripts/benchmark_xland.py @@ -4,6 +4,7 @@ from typing import Optional import jax +import jax.tree_util as jtu import numpy as np import xminigrid from xminigrid import load_benchmark @@ -49,7 +50,7 @@ def _body_fn(timestep, action): # see https://stackoverflow.com/questions/56763416/what-is-diffrence-between-number-and-repeat-in-python-timeit # on why we divide by args.num_iter -def timeit_benchmark(benchmark_fn): +def timeit_benchmark(args, benchmark_fn): t = time.time() benchmark_fn().state.grid.block_until_ready() print(f"Compilation time: {time.time() - t}") @@ -85,15 +86,15 @@ def timeit_benchmark(benchmark_fn): pmap_keys = jax.random.split(key, num=num_devices) # benchmarking - elapsed_time = timeit_benchmark(jax.tree_util.Partial(benchmark_fn_single, key)) + elapsed_time = timeit_benchmark(args, jtu.Partial(benchmark_fn_single, key)) single_fps = args.timesteps / elapsed_time print(f"Single env, Elapsed time: {elapsed_time:.5f}s, FPS: {single_fps:.0f}") print() - elapsed_time = timeit_benchmark(jax.tree_util.Partial(benchmark_fn_vmap, key)) + elapsed_time = timeit_benchmark(args, jtu.Partial(benchmark_fn_vmap, key)) vmap_fps = (args.timesteps * args.num_envs) / elapsed_time print(f"Vmap env, Elapsed time: {elapsed_time:.5f}s, FPS: {vmap_fps:.0f}") print() - elapsed_time = timeit_benchmark(jax.tree_util.Partial(benchmark_fn_pmap, pmap_keys)) + elapsed_time = timeit_benchmark(args, jtu.Partial(benchmark_fn_pmap, pmap_keys)) pmap_fps = (args.timesteps * args.num_envs) / elapsed_time print(f"Pmap env, Elapsed time: {elapsed_time:.5f}s, FPS: {pmap_fps:.0f}") print() diff --git a/scripts/benchmark_xland_all.py b/scripts/benchmark_xland_all.py new file mode 100644 index 0000000..9a5f2d3 --- /dev/null +++ b/scripts/benchmark_xland_all.py @@ -0,0 +1,91 @@ +# Performance benchmark for all environments. 
For the paper and to check regressions after new features. +import argparse +import pprint +import timeit +from typing import Optional + +import jax +import jax.tree_util as jtu +import numpy as np +import xminigrid +from tqdm.auto import tqdm +from xminigrid import load_benchmark +from xminigrid.wrappers import GymAutoResetWrapper + +jax.config.update("jax_threefry_partitionable", True) + +NUM_ENVS = (512, 1024, 2048, 4096, 8192) + +parser = argparse.ArgumentParser() +parser.add_argument("--benchmark-id", type=str, default="Trivial") +parser.add_argument("--timesteps", type=int, default=1000) +parser.add_argument("--num-repeat", type=int, default=10, help="Number of timing repeats") +parser.add_argument("--num-iter", type=int, default=1, help="Number of runs during one repeat (time is summed)") + + +def build_benchmark(env_id: str, num_envs: int, timesteps: int, benchmark_id: Optional[str] = None): + env, env_params = xminigrid.make(env_id) + env = GymAutoResetWrapper(env) + # choose XLand benchmark if needed + if "XLand-MiniGrid" in env_id and benchmark_id is not None: + ruleset = load_benchmark(benchmark_id).sample_ruleset(jax.random.PRNGKey(0)) + env_params = env_params.replace(ruleset=ruleset) + + def benchmark_fn(key): + def _body_fn(timestep, action): + new_timestep = jax.vmap(env.step, in_axes=(None, 0, 0))(env_params, timestep, action) + return new_timestep, None + + key, actions_key = jax.random.split(key) + keys = jax.random.split(key, num=num_envs) + actions = jax.random.randint( + actions_key, shape=(timesteps, num_envs), minval=0, maxval=env.num_actions(env_params) + ) + + timestep = jax.vmap(env.reset, in_axes=(None, 0))(env_params, keys) + # unroll can affect FPS greatly !!! 
+ timestep = jax.lax.scan(_body_fn, timestep, actions, unroll=1)[0] + return timestep + + return benchmark_fn + + +# see https://stackoverflow.com/questions/56763416/what-is-diffrence-between-number-and-repeat-in-python-timeit +# on why we divide by args.num_iter +def timeit_benchmark(args, benchmark_fn): + benchmark_fn().state.grid.block_until_ready() + times = timeit.repeat( + lambda: benchmark_fn().state.grid.block_until_ready(), + number=args.num_iter, + repeat=args.num_repeat, + ) + times = np.array(times) / args.num_iter + elapsed_time = np.max(times) + return elapsed_time + + +# that can take a while! +if __name__ == "__main__": + num_devices = jax.local_device_count() + args = parser.parse_args() + print("Num devices:", num_devices) + + summary = {} + for num_envs in tqdm(NUM_ENVS, desc="Benchmark", leave=False): + results = {} + for env_id in tqdm(xminigrid.registered_environments(), desc="Envs.."): + assert num_envs % num_devices == 0 + # building pmap for multi-gpu benchmarking (each doing (num_envs / num_devices) vmaps) + benchmark_fn_pmap = build_benchmark(env_id, num_envs // num_devices, args.timesteps, args.benchmark_id) + benchmark_fn_pmap = jax.pmap(benchmark_fn_pmap) + + # benchmarking + pmap_keys = jax.random.split(jax.random.PRNGKey(0), num=num_devices) + + elapsed_time = timeit_benchmark(args, jtu.Partial(benchmark_fn_pmap, pmap_keys)) + pmap_fps = (args.timesteps * num_envs) // elapsed_time + + results[env_id] = int(pmap_fps) + summary[num_envs] = results + + pprint.pprint(summary) diff --git a/src/xminigrid/core/rules.py b/src/xminigrid/core/rules.py index 7d2adbb..15cfbbc 100644 --- a/src/xminigrid/core/rules.py +++ b/src/xminigrid/core/rules.py @@ -7,9 +7,10 @@ from .constants import TILES_REGISTRY, Colors, Tiles from .grid import equal, get_neighbouring_tiles, pad_along_axis -MAX_RULE_ENCODING_LEN = 6 + 1 # for idx +MAX_RULE_ENCODING_LEN = 6 + 1 # +1 for idx +# this is very costly, will evaluate all under vmap. 
Submit a PR if you know how to do it better! def check_rule(encodings, grid, agent, action, position): def _check(carry, encoding): grid, agent = carry @@ -22,6 +23,14 @@ def _check(carry, encoding): lambda: AgentHoldRule.decode(encoding)(grid, agent, action, position), lambda: AgentNearRule.decode(encoding)(grid, agent, action, position), lambda: TileNearRule.decode(encoding)(grid, agent, action, position), + lambda: TileNearUpRule.decode(encoding)(grid, agent, action, position), + lambda: TileNearRightRule.decode(encoding)(grid, agent, action, position), + lambda: TileNearDownRule.decode(encoding)(grid, agent, action, position), + lambda: TileNearLeftRule.decode(encoding)(grid, agent, action, position), + lambda: AgentNearUpRule.decode(encoding)(grid, agent, action, position), + lambda: AgentNearRightRule.decode(encoding)(grid, agent, action, position), + lambda: AgentNearDownRule.decode(encoding)(grid, agent, action, position), + lambda: AgentNearLeftRule.decode(encoding)(grid, agent, action, position), ), ) return (grid, agent), None @@ -172,3 +181,276 @@ def decode(cls, encoding): def encode(self): encoding = jnp.hstack([jnp.asarray(3), self.tile_a, self.tile_b, self.prod_tile], dtype=jnp.uint8) return pad_along_axis(encoding, MAX_RULE_ENCODING_LEN) + + +# tile_b must be directly above tile_a (one tile up) +class TileNearUpRule(BaseRule): + tile_a: jax.Array + tile_b: jax.Array + prod_tile: jax.Array + + def __call__(self, grid, agent, action, position): + tile = grid[position[0], position[1]] + + def _rule_fn(grid): + empty_tile = TILES_REGISTRY[Tiles.EMPTY, Colors.EMPTY] + y, x = position + up, _, down, _ = get_neighbouring_tiles(grid, y, x) + + grid = jax.lax.select( + equal(tile, self.tile_b) & equal(down, self.tile_a), + grid.at[y + 1, x].set(self.prod_tile).at[y, x].set(empty_tile), + grid, + ) + grid = jax.lax.select( + equal(tile, self.tile_a) & equal(up, self.tile_b), + grid.at[y - 1, x].set(self.prod_tile).at[y, x].set(empty_tile), + grid, + ) + return 
grid + + grid = jax.lax.cond( + jnp.equal(action, 4) & (equal(tile, self.tile_a) | equal(tile, self.tile_b)), + lambda: _rule_fn(grid), + lambda: grid, + ) + return grid, agent + + @classmethod + def decode(cls, encoding): + return cls(tile_a=encoding[1:3], tile_b=encoding[3:5], prod_tile=encoding[5:7]) + + def encode(self): + encoding = jnp.hstack([jnp.asarray(4), self.tile_a, self.tile_b, self.prod_tile], dtype=jnp.uint8) + return pad_along_axis(encoding, MAX_RULE_ENCODING_LEN) + + +class TileNearRightRule(BaseRule): + tile_a: jax.Array + tile_b: jax.Array + prod_tile: jax.Array + + def __call__(self, grid, agent, action, position): + tile = grid[position[0], position[1]] + + def _rule_fn(grid): + empty_tile = TILES_REGISTRY[Tiles.EMPTY, Colors.EMPTY] + y, x = position + _, right, _, left = get_neighbouring_tiles(grid, y, x) + + grid = jax.lax.select( + equal(tile, self.tile_b) & equal(left, self.tile_a), + grid.at[y, x - 1].set(self.prod_tile).at[y, x].set(empty_tile), + grid, + ) + grid = jax.lax.select( + equal(tile, self.tile_a) & equal(right, self.tile_b), + grid.at[y, x + 1].set(self.prod_tile).at[y, x].set(empty_tile), + grid, + ) + return grid + + grid = jax.lax.cond( + jnp.equal(action, 4) & (equal(tile, self.tile_a) | equal(tile, self.tile_b)), + lambda: _rule_fn(grid), + lambda: grid, + ) + return grid, agent + + @classmethod + def decode(cls, encoding): + return cls(tile_a=encoding[1:3], tile_b=encoding[3:5], prod_tile=encoding[5:7]) + + def encode(self): + encoding = jnp.hstack([jnp.asarray(5), self.tile_a, self.tile_b, self.prod_tile], dtype=jnp.uint8) + return pad_along_axis(encoding, MAX_RULE_ENCODING_LEN) + + +class TileNearDownRule(BaseRule): + tile_a: jax.Array + tile_b: jax.Array + prod_tile: jax.Array + + def __call__(self, grid, agent, action, position): + tile = grid[position[0], position[1]] + + def _rule_fn(grid): + empty_tile = TILES_REGISTRY[Tiles.EMPTY, Colors.EMPTY] + y, x = position + up, _, down, _ = get_neighbouring_tiles(grid, y, 
x) + + grid = jax.lax.select( + equal(tile, self.tile_b) & equal(up, self.tile_a), + grid.at[y - 1, x].set(self.prod_tile).at[y, x].set(empty_tile), + grid, + ) + grid = jax.lax.select( + equal(tile, self.tile_a) & equal(down, self.tile_b), + grid.at[y + 1, x].set(self.prod_tile).at[y, x].set(empty_tile), + grid, + ) + return grid + + grid = jax.lax.cond( + jnp.equal(action, 4) & (equal(tile, self.tile_a) | equal(tile, self.tile_b)), + lambda: _rule_fn(grid), + lambda: grid, + ) + return grid, agent + + @classmethod + def decode(cls, encoding): + return cls(tile_a=encoding[1:3], tile_b=encoding[3:5], prod_tile=encoding[5:7]) + + def encode(self): + encoding = jnp.hstack([jnp.asarray(6), self.tile_a, self.tile_b, self.prod_tile], dtype=jnp.uint8) + return pad_along_axis(encoding, MAX_RULE_ENCODING_LEN) + + +class TileNearLeftRule(BaseRule): + tile_a: jax.Array + tile_b: jax.Array + prod_tile: jax.Array + + def __call__(self, grid, agent, action, position): + tile = grid[position[0], position[1]] + + def _rule_fn(grid): + empty_tile = TILES_REGISTRY[Tiles.EMPTY, Colors.EMPTY] + y, x = position + _, right, _, left = get_neighbouring_tiles(grid, y, x) + + grid = jax.lax.select( + equal(tile, self.tile_b) & equal(right, self.tile_a), + grid.at[y, x + 1].set(self.prod_tile).at[y, x].set(empty_tile), + grid, + ) + grid = jax.lax.select( + equal(tile, self.tile_a) & equal(left, self.tile_b), + grid.at[y, x - 1].set(self.prod_tile).at[y, x].set(empty_tile), + grid, + ) + return grid + + grid = jax.lax.cond( + jnp.equal(action, 4) & (equal(tile, self.tile_a) | equal(tile, self.tile_b)), + lambda: _rule_fn(grid), + lambda: grid, + ) + return grid, agent + + @classmethod + def decode(cls, encoding): + return cls(tile_a=encoding[1:3], tile_b=encoding[3:5], prod_tile=encoding[5:7]) + + def encode(self): + encoding = jnp.hstack([jnp.asarray(7), self.tile_a, self.tile_b, self.prod_tile], dtype=jnp.uint8) + return pad_along_axis(encoding, MAX_RULE_ENCODING_LEN) + + +class 
AgentNearUpRule(BaseRule): + tile: jax.Array + prod_tile: jax.Array + + def __call__(self, grid, agent, action, position): + def _rule_fn(grid): + y, x = agent.position + up, _, _, _ = get_neighbouring_tiles(grid, y, x) + grid = jax.lax.select( + equal(up, self.tile), + grid.at[y - 1, x].set(self.prod_tile), + grid, + ) + return grid + + grid = jax.lax.cond(jnp.equal(action, 0) | jnp.equal(action, 4), lambda: _rule_fn(grid), lambda: grid) + return grid, agent + + @classmethod + def decode(cls, encoding): + return cls(tile=encoding[1:3], prod_tile=encoding[3:5]) + + def encode(self): + encoding = jnp.hstack([jnp.asarray(8), self.tile, self.prod_tile], dtype=jnp.uint8) + return pad_along_axis(encoding, MAX_RULE_ENCODING_LEN) + + +class AgentNearRightRule(BaseRule): + tile: jax.Array + prod_tile: jax.Array + + def __call__(self, grid, agent, action, position): + def _rule_fn(grid): + y, x = agent.position + _, right, _, _ = get_neighbouring_tiles(grid, y, x) + grid = jax.lax.select( + equal(right, self.tile), + grid.at[y, x + 1].set(self.prod_tile), + grid, + ) + return grid + + grid = jax.lax.cond(jnp.equal(action, 0) | jnp.equal(action, 4), lambda: _rule_fn(grid), lambda: grid) + return grid, agent + + @classmethod + def decode(cls, encoding): + return cls(tile=encoding[1:3], prod_tile=encoding[3:5]) + + def encode(self): + encoding = jnp.hstack([jnp.asarray(9), self.tile, self.prod_tile], dtype=jnp.uint8) + return pad_along_axis(encoding, MAX_RULE_ENCODING_LEN) + + +class AgentNearDownRule(BaseRule): + tile: jax.Array + prod_tile: jax.Array + + def __call__(self, grid, agent, action, position): + def _rule_fn(grid): + y, x = agent.position + _, _, down, _ = get_neighbouring_tiles(grid, y, x) + grid = jax.lax.select( + equal(down, self.tile), + grid.at[y + 1, x].set(self.prod_tile), + grid, + ) + return grid + + grid = jax.lax.cond(jnp.equal(action, 0) | jnp.equal(action, 4), lambda: _rule_fn(grid), lambda: grid) + return grid, agent + + @classmethod + def 
decode(cls, encoding): + return cls(tile=encoding[1:3], prod_tile=encoding[3:5]) + + def encode(self): + encoding = jnp.hstack([jnp.asarray(10), self.tile, self.prod_tile], dtype=jnp.uint8) + return pad_along_axis(encoding, MAX_RULE_ENCODING_LEN) + + +class AgentNearLeftRule(BaseRule): + tile: jax.Array + prod_tile: jax.Array + + def __call__(self, grid, agent, action, position): + def _rule_fn(grid): + y, x = agent.position + _, _, _, left = get_neighbouring_tiles(grid, y, x) + grid = jax.lax.select( + equal(left, self.tile), + grid.at[y, x - 1].set(self.prod_tile), + grid, + ) + return grid + + grid = jax.lax.cond(jnp.equal(action, 0) | jnp.equal(action, 4), lambda: _rule_fn(grid), lambda: grid) + return grid, agent + + @classmethod + def decode(cls, encoding): + return cls(tile=encoding[1:3], prod_tile=encoding[3:5]) + + def encode(self): + encoding = jnp.hstack([jnp.asarray(11), self.tile, self.prod_tile], dtype=jnp.uint8) + return pad_along_axis(encoding, MAX_RULE_ENCODING_LEN) diff --git a/src/xminigrid/envs/minigrid/blockedunlockpickup.py b/src/xminigrid/envs/minigrid/blockedunlockpickup.py index 40a542c..ae557ae 100644 --- a/src/xminigrid/envs/minigrid/blockedunlockpickup.py +++ b/src/xminigrid/envs/minigrid/blockedunlockpickup.py @@ -24,6 +24,8 @@ Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, + Tiles.HEX, + Tiles.STAR, ) ) _rule_encoding = EmptyRule().encode()[None, ...] 
diff --git a/src/xminigrid/envs/minigrid/playground.py b/src/xminigrid/envs/minigrid/playground.py index 9d145ef..994347d 100644 --- a/src/xminigrid/envs/minigrid/playground.py +++ b/src/xminigrid/envs/minigrid/playground.py @@ -30,7 +30,7 @@ _allowed_colors, ) _allowed_objects = cartesian_product_1d( - jnp.array((Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.GOAL), dtype=jnp.uint8), + jnp.array((Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.STAR, Tiles.HEX, Tiles.GOAL), dtype=jnp.uint8), _allowed_colors, ) # number of doors with 9 rooms diff --git a/src/xminigrid/envs/minigrid/unlockpickup.py b/src/xminigrid/envs/minigrid/unlockpickup.py index 1fc1443..f98c324 100644 --- a/src/xminigrid/envs/minigrid/unlockpickup.py +++ b/src/xminigrid/envs/minigrid/unlockpickup.py @@ -24,6 +24,8 @@ Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, + Tiles.STAR, + Tiles.HEX, ) ) _rule_encoding = EmptyRule().encode()[None, ...]