Skip to content

Commit

Permalink
added new rules for the extended benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
Howuhh committed Dec 13, 2023
1 parent c6eebb1 commit c27d2f5
Show file tree
Hide file tree
Showing 7 changed files with 385 additions and 7 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ install the source as follows:
git clone git@github.com:corl-team/xland-minigrid.git
cd xland-minigrid
# additional dependencies for baselines
pip install -e ".[benchmark]"
pip install -e ".[dev,benchmark]"
```
Note that the installation of JAX may differ depending on your hardware accelerator!
We advise users to explicitly install the correct JAX version (see the [official installation guide](https://github.com/google/jax#installation)).
Expand Down
9 changes: 5 additions & 4 deletions scripts/benchmark_xland.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Optional

import jax
import jax.tree_util as jtu
import numpy as np
import xminigrid
from xminigrid import load_benchmark
Expand Down Expand Up @@ -49,7 +50,7 @@ def _body_fn(timestep, action):

# see https://stackoverflow.com/questions/56763416/what-is-diffrence-between-number-and-repeat-in-python-timeit
# on why we divide by args.num_iter
def timeit_benchmark(benchmark_fn):
def timeit_benchmark(args, benchmark_fn):
t = time.time()
benchmark_fn().state.grid.block_until_ready()
print(f"Compilation time: {time.time() - t}")
Expand Down Expand Up @@ -85,15 +86,15 @@ def timeit_benchmark(benchmark_fn):
pmap_keys = jax.random.split(key, num=num_devices)

# benchmarking
elapsed_time = timeit_benchmark(jax.tree_util.Partial(benchmark_fn_single, key))
elapsed_time = timeit_benchmark(args, jtu.Partial(benchmark_fn_single, key))
single_fps = args.timesteps / elapsed_time
print(f"Single env, Elapsed time: {elapsed_time:.5f}s, FPS: {single_fps:.0f}")
print()
elapsed_time = timeit_benchmark(jax.tree_util.Partial(benchmark_fn_vmap, key))
elapsed_time = timeit_benchmark(args, jtu.Partial(benchmark_fn_vmap, key))
vmap_fps = (args.timesteps * args.num_envs) / elapsed_time
print(f"Vmap env, Elapsed time: {elapsed_time:.5f}s, FPS: {vmap_fps:.0f}")
print()
elapsed_time = timeit_benchmark(jax.tree_util.Partial(benchmark_fn_pmap, pmap_keys))
elapsed_time = timeit_benchmark(args, jtu.Partial(benchmark_fn_pmap, pmap_keys))
pmap_fps = (args.timesteps * args.num_envs) / elapsed_time
print(f"Pmap env, Elapsed time: {elapsed_time:.5f}s, FPS: {pmap_fps:.0f}")
print()
Expand Down
91 changes: 91 additions & 0 deletions scripts/benchmark_xland_all.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Performance benchmark for all environments. For the paper and to check regressions after new features.
import argparse
import pprint
import timeit
from typing import Optional

import jax
import jax.tree_util as jtu
import numpy as np
import xminigrid
from tqdm.auto import tqdm
from xminigrid import load_benchmark
from xminigrid.wrappers import GymAutoResetWrapper

jax.config.update("jax_threefry_partitionable", True)

NUM_ENVS = (512, 1024, 2048, 4096, 8192)

parser = argparse.ArgumentParser()
parser.add_argument("--benchmark-id", type=str, default="Trivial")
parser.add_argument("--timesteps", type=int, default=1000)
parser.add_argument("--num-repeat", type=int, default=10, help="Number of timing repeats")
parser.add_argument("--num-iter", type=int, default=1, help="Number of runs during one repeat (time is summed)")


def build_benchmark(env_id: str, num_envs: int, timesteps: int, benchmark_id: Optional[str] = None):
env, env_params = xminigrid.make(env_id)
env = GymAutoResetWrapper(env)
# choose XLand benchmark if needed
if "XLand-MiniGrid" in env_id and benchmark_id is not None:
ruleset = load_benchmark(benchmark_id).sample_ruleset(jax.random.PRNGKey(0))
env_params = env_params.replace(ruleset=ruleset)

def benchmark_fn(key):
def _body_fn(timestep, action):
new_timestep = jax.vmap(env.step, in_axes=(None, 0, 0))(env_params, timestep, action)
return new_timestep, None

key, actions_key = jax.random.split(key)
keys = jax.random.split(key, num=num_envs)
actions = jax.random.randint(
actions_key, shape=(timesteps, num_envs), minval=0, maxval=env.num_actions(env_params)
)

timestep = jax.vmap(env.reset, in_axes=(None, 0))(env_params, keys)
# unroll can affect FPS greatly !!!
timestep = jax.lax.scan(_body_fn, timestep, actions, unroll=1)[0]
return timestep

return benchmark_fn


# see https://stackoverflow.com/questions/56763416/what-is-diffrence-between-number-and-repeat-in-python-timeit
# on why we divide by args.num_iter
def timeit_benchmark(args, benchmark_fn):
benchmark_fn().state.grid.block_until_ready()
times = timeit.repeat(
lambda: benchmark_fn().state.grid.block_until_ready(),
number=args.num_iter,
repeat=args.num_repeat,
)
times = np.array(times) / args.num_iter
elapsed_time = np.max(times)
return elapsed_time


# that can take a while!
if __name__ == "__main__":
num_devices = jax.local_device_count()
args = parser.parse_args()
print("Num devices:", num_devices)

summary = {}
for num_envs in tqdm(NUM_ENVS, desc="Benchmark", leave=False):
results = {}
for env_id in tqdm(xminigrid.registered_environments(), desc="Envs.."):
assert num_envs % num_devices == 0
# building pmap for multi-gpu benchmarking (each doing (num_envs / num_devices) vmaps)
benchmark_fn_pmap = build_benchmark(env_id, num_envs // num_devices, args.timesteps, args.benchmark_id)
benchmark_fn_pmap = jax.pmap(benchmark_fn_pmap)

# benchmarking
pmap_keys = jax.random.split(jax.random.PRNGKey(0), num=num_devices)

elapsed_time = timeit_benchmark(args, jtu.Partial(benchmark_fn_pmap, pmap_keys))
pmap_fps = (args.timesteps * num_envs) // elapsed_time

results[env_id] = int(pmap_fps)
summary[num_envs] = results

pprint.pprint(summary)
Loading

0 comments on commit c27d2f5

Please sign in to comment.