Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use dask instead of multiprocessing for lunar lander tutorial (#346) #347

Merged
merged 12 commits into from
Aug 21, 2023
9 changes: 9 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
# History

## (Upcoming)

### Changelog

#### Documentation

- Use dask instead of multiprocessing for lunar lander tutorial (#346)
- pip install swig before gymnasium[box2d] in lunar lander tutorial (#346)

## 0.5.2

This release contains miscellaneous edits to our documentation from v0.5.1.
Expand Down
16 changes: 11 additions & 5 deletions tutorials/lunar_lander.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@
},
"outputs": [],
"source": [
"%pip install ribs[visualize] gymnasium[box2d]==0.27.0 \"moviepy>=1.0.0\"\n",
"%pip install swig # Must be installed before box2d\n",
"%pip install ribs[visualize] gymnasium[box2d]==0.27.0 \"moviepy>=1.0.0\" dask distributed\n",
"\n",
"# An uninstalled version of decorator is occasionally loaded. This loads the\n",
"# newly installed version of decorator so that moviepy works properly -- see\n",
Expand All @@ -178,10 +179,10 @@
},
"outputs": [],
"source": [
"import multiprocessing\n",
"import sys\n",
"import time\n",
"\n",
"from dask.distributed import Client\n",
"import gymnasium as gym\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
Expand Down Expand Up @@ -460,7 +461,7 @@
"\n",
"With the pyribs components defined, we start searching with CMA-ME. Since we use 5 emitters each with a batch size of 30 and we run 300 iterations, we run 5 x 30 x 300 = 45,000 lunar lander simulations. We also keep track of some logging info via `archive.stats`, which is an [`ArchiveStats`](https://docs.pyribs.org/en/latest/api/ribs.archives.ArchiveStats.html) object.\n",
"\n",
"Since it takes a relatively long time to evaluate a lunar lander solution, we parallelize the evaluation of multiple solutions with Python's [multiprocessing module](https://docs.python.org/3/library/multiprocessing.html), specifically the [`starmap`](https://docs.python.org/3/library/multiprocessing.html#multiprocessing.pool.Pool.starmap) method of [`multiprocessing.Pool`](https://docs.python.org/3/library/multiprocessing.html#multiprocessing.pool.Pool). With one worker, the loop should take **2 hours** to run. With two workers, it should take **1 hour** to run. Feel free to increase the number of workers based on the number of CPUs your system has available to speed up the loop further."
"Since it takes a relatively long time to evaluate a lunar lander solution, we parallelize the evaluation of multiple solutions with [Dask](https://distributed.dask.org/en/stable/quickstart.html). With one worker (i.e., one CPU), the loop should take **2 hours** to run. With two workers, it should take **1 hour** to run. Feel free to increase the number of workers based on the number of CPUs your system has available to speed up the loop further."
]
},
{
Expand Down Expand Up @@ -559,13 +560,18 @@
"total_itrs = 300\n",
"workers = 2 # Adjust the number of workers based on your available CPUs.\n",
"\n",
"client = Client(\n",
" n_workers=workers, # Create this many worker processes using Dask LocalCluster.\n",
" threads_per_worker=1, # Each worker process is single-threaded.\n",
")\n",
"\n",
"for itr in trange(1, total_itrs + 1, file=sys.stdout, desc='Iterations'):\n",
" # Request models from the scheduler.\n",
" sols = scheduler.ask()\n",
"\n",
" # Evaluate the models and record the objectives and measuress.\n",
" with multiprocessing.Pool(workers) as pool:\n",
" results = pool.starmap(simulate, [(model, env_seed) for model in sols])\n",
" futures = client.map(lambda model: simulate(model, env_seed), sols)\n",
" results = client.gather(futures)\n",
"\n",
" objs, meas = [], []\n",
" for obj, impact_x_pos, impact_y_vel in results:\n",
Expand Down
Loading