Commit

tested on colab
Howuhh committed Dec 3, 2023
1 parent 011a025 commit b40cff5
Showing 3 changed files with 32 additions and 7 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -176,6 +176,8 @@ python scripts/ruleset_generator.py --help

An in-depth description of all available benchmarks is provided [here (soon)]().

**P.S.** Currently only one benchmark is available. We will release more after additional testing and config balancing. Stay tuned!

## Environments 🌍

We provide environments from two domains. `XLand` is our main focus for meta-learning. For this domain we provide single
35 changes: 29 additions & 6 deletions examples/walkthrough.ipynb
@@ -21,7 +21,7 @@
"\n",
"Welcome to the walkthrough the XLand-MiniGrid library. This notebook will showcase our environments and bechmarks APIs, explaning the details and our motivations. It will also provide vectorization and multi-device parallelization examples. For full baselines training demo see our notebooks with standalone PPO implementations.\n",
"\n",
"⚠️ Ensure you select a GPU/TPU from `Runtime > Change runtime type` ⚠️"
"> ⚠️ Ensure you select a GPU from `Runtime > Change runtime type` ⚠️"
]
},
{
@@ -47,15 +47,19 @@
"metadata": {},
"outputs": [],
"source": [
"!pip install --upgrade \"jax[cuda11_pip]\" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
"!pip install 'xminigrid[baselines]' "
"# jax is already installed on the colab, uncomment only if needed\n",
"# !pip install --upgrade \"jax[cuda11_pip]\" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
"\n",
"!pip install xminigrid"
]
},
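After installation it is worth confirming that JAX actually sees the accelerator, since it falls back to CPU silently. A minimal sanity-check cell (an illustrative sketch using the standard JAX API, not part of the diff above):

```python
import jax

# confirm the runtime exposes an accelerator; on a CPU-only runtime
# jax falls back silently, so it is worth checking explicitly
print(jax.default_backend())  # expected "gpu" on a Colab GPU runtime
print(jax.devices())          # e.g. [cuda(id=0)]
```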
{
"cell_type": "markdown",
"id": "4c2c028d-96fb-4413-80f5-f504285d8f42",
"metadata": {},
"source": [
"> ⚠️ Colab TPU runtime is not compatible with new JAX versions (> 0.4.0). Please, use kaggle notebooks if you want to use TPUs. There is no quick way to open notebook from github in kaggle (like colab badge above), so you will need to manually upload it. ⚠️\n",
"\n",
"If you have chosen the TPU runtime, some additional actions are needed. Please, uncomment these and setup the TPU through JAX. Check that devices are `TpuDevice`."
]
},
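For reference, the setup lines this cell alludes to look roughly like the following (a sketch assuming the legacy `jax.tools.colab_tpu` helper, which only works with the older JAX versions mentioned above):

```python
# legacy Colab TPU initialization; only valid for older jax (<= 0.4.0) runtimes
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

import jax
print(jax.devices())  # should list TpuDevice entries, 8 cores on a v2/v3 TPU
```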
@@ -411,7 +415,7 @@
"source": [
"from IPython.display import Video\n",
"\n",
"Video(\"example_rollout.mp4\")"
"Video(\"example_rollout.mp4\", embed=True)"
]
},
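`embed=True` inlines the video bytes into the notebook output instead of linking the file, so the rollout still renders on Colab and in saved notebooks. A sketch of the producing side, assuming `frames` is the list of RGB arrays collected during the rollout above and that `imageio` (with its ffmpeg plugin) is available:

```python
import imageio
from IPython.display import Video

# frames: list of HxWx3 uint8 arrays from the rollout above (assumed)
imageio.mimsave("example_rollout.mp4", frames, fps=16)
Video("example_rollout.mp4", embed=True)
```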
{
@@ -535,7 +539,7 @@
"outputs": [],
"source": [
"num_devices = jax.local_device_count()\n",
"envs_range = [128, 256, 512, 1024, 2048]\n",
"envs_range = [512, 1024, 2048, 4096, 8192]\n",
"print(\"Num devices for pmap:\", num_devices)\n",
"\n",
"vmap_stats, pmap_stats = [], []\n",
@@ -560,6 +564,25 @@
" pmap_stats.append(pmap_fps)"
]
},
{
"cell_type": "markdown",
"id": "eee66c3c-264c-4223-87f9-7e38d99bbf98",
"metadata": {},
"source": [
"Note that the actual values may differ from the reported in the paper, as the default colab GPU is less powerfull that A100 used in the main experiments. But even on the one T4 GPU users can expect to get ~**10M** steps per second (and ~**25M** with two T4 GPUs on kaggle notebooks with `pmap`)!! However, compared to A100 scaling will saturate a lot earilier with respect to number of parallel environments."
]
},
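For context, throughput numbers like these come from timing jitted, vmapped environment steps. A condensed sketch of that kind of measurement, assuming the `xminigrid.make` / `env.reset` / `env.step` API shown in the repository README (the environment ID and step count are placeholders, and the full benchmark cell lives in the notebook itself):

```python
import time
import jax
import jax.numpy as jnp
import xminigrid

env, env_params = xminigrid.make("MiniGrid-Empty-8x8")  # assumed env ID

def benchmark_vmap(num_envs: int, num_steps: int = 1000) -> float:
    reset_fn = jax.jit(jax.vmap(env.reset, in_axes=(None, 0)))
    step_fn = jax.jit(jax.vmap(env.step, in_axes=(None, 0, 0)))

    keys = jax.random.split(jax.random.PRNGKey(0), num_envs)
    timestep = reset_fn(env_params, keys)
    actions = jnp.zeros(num_envs, dtype=jnp.int32)

    # warm-up call so compilation time is excluded from the measurement
    jax.block_until_ready(step_fn(env_params, timestep, actions))

    start = time.monotonic()
    for _ in range(num_steps):
        timestep = step_fn(env_params, timestep, actions)
    jax.block_until_ready(timestep)
    return num_envs * num_steps / (time.monotonic() - start)

print(f"FPS: {benchmark_vmap(4096):.0f}")
```

The `pmap` variant measured above follows the same pattern, additionally sharding the batch of environments across all local devices.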
{
"cell_type": "code",
"execution_count": null,
"id": "f9001dd7-90dd-4917-a53a-9b9aed2fcd45",
"metadata": {},
"outputs": [],
"source": [
"for n, vfps, pfps in zip(envs_range, vmap_stats, pmap_stats):\n",
" print(f\"{n} envs. vmap fps: {vfps}, pmap fps: {pfps}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -1002,7 +1025,7 @@
"id": "a0b55667-d2c3-4b5c-b074-d98c0b14abc9",
"metadata": {},
"source": [
"As we vmapped only on rulesets and not the keys, initial positions for agent and objects are the same. In practice, we also vmap on keys."
"As we vmapped only on rulesets and not the keys, initial positions for agent and objects are the same. In practice, we also vmap on keys to randomize their positions too."
]
},
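A sketch of what vmapping over both looks like, assuming the `env`, `env_params`, and a batched `rulesets` pytree from the cells above (the `replace` call follows the repo README, but treat this as illustrative rather than the notebook's exact code):

```python
import jax

# vary both the ruleset and the rng key per environment instance, so
# layouts AND agent/object starting positions differ across the batch
def reset_one(ruleset, key):
    return env.reset(env_params.replace(ruleset=ruleset), key)

keys = jax.random.split(jax.random.PRNGKey(0), num_rulesets)  # num_rulesets: assumed batch size
timesteps = jax.vmap(reset_one)(rulesets, keys)
```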
{
2 changes: 1 addition & 1 deletion src/xminigrid/envs/xland.py
@@ -151,7 +151,7 @@ def time_limit(self, params: XLandMiniGridEnvOptions) -> int:

def _generate_problem(self, params: XLandMiniGridEnvOptions, key: jax.Array) -> State:
# WARN: we can make this compatible with jit (to vmap on different layouts during training),
- # but it will be very costly, as lax.switch will generate all layouts during reset under vmap
+ # but it will probably be very costly, as lax.switch will generate all layouts during reset under vmap
 # TODO: experiment with this under jit, is it possible to make it jit-compatible without overhead?
if params.grid_type == "R1":
key, grid = generate_room(key, params.height, params.width)
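The cost mentioned in the comment above is a general JAX behavior: under `vmap`, `lax.switch` with a batched index is lowered to executing every branch and selecting the result per element. A minimal standalone illustration:

```python
import jax
import jax.numpy as jnp
from jax import lax

branches = [lambda x: x + 1, lambda x: x * 2, lambda x: x**2]

def pick(i, x):
    return lax.switch(i, branches, x)

# with a batched index, all three branches are evaluated and then selected,
# which is why switching between layout generators under vmap is costly
print(jax.vmap(pick)(jnp.array([0, 1, 2]), jnp.arange(3.0)))  # [1. 2. 4.]
```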
