Skip to content

Commit 768881a

Browse files
committed
distributed tutorial
1 parent 8f14e77 commit 768881a

File tree

1 file changed

+4
-10
lines changed

1 file changed

+4
-10
lines changed

docs/tutorials/distributed_computation.rst

+4-10
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,8 @@ Distributed computation
33

44
We present here how to perform computation on multiple devices.
55

6-
Imagine to have at your disposal 4 GPUs and you want to distribute the workload on them.
7-
There are two ways to do so:
8-
9-
* Create 4 simulators, specifying a different device for each one
10-
* Use the `JAX pmap <https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html>`_ function to wrap the functions you need.
11-
12-
If memory is not an issue, the second method is the easiest one. In fact, you simply need divide your population in groups (i.e. divide the first axis) and distribute over the groups.
6+
Consider a scenario where you have access to four GPUs and aim to distribute the workload effectively among them.
7+
To achieve this, we employ the `JAX pmap <https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html>`_ function, which allows seamless distribution of functions across multiple accelerators.
138

149
.. code-block:: python
1510
@@ -21,7 +16,7 @@ If memory is not an issue, the second method is the easiest one. In fact, you si
2116
# load 200 individuals
2217
population = simulator.load_population(genome)[:200]
2318
# divide them in 4 groups
24-
population = population.reshape(4, 50, *population.shape[1:])
19+
population = population.reshape(4, -1, *population.shape[1:])
2520
2621
# prepare a parallelized function over groups
2722
pmap_dh = jax.pmap(
@@ -35,7 +30,6 @@ If memory is not an issue, the second method is the easiest one. In fact, you si
3530
dh_pop = dh_pop.reshape(-1, *dh_pop.shape[2:])
3631
3732
38-
3933
If you want to perform random crosses or full diallel, grouping the population will change the semantics (the random crosses or the full diallel will be performed by group independently).
4034
In this case, you should use the function ``cross`` after generating the proper array of parents.
4135
For example, to perform random crosses:
@@ -51,7 +45,7 @@ For example, to perform random crosses:
5145
5246
random_indices = np.random.random_integers(0, len(population) - 1, size=(200, 2))
5347
parents = population[random_indices]
54-
parents = parents.reshape(4, 50, *parents.shape[1:])
48+
parents = parents.reshape(4, -1, *parents.shape[1:])
5549
pmap_cross = jax.pmap(simulator.cross,)
5650
new_pop = pmap_cross(parents)
5751
new_pop = new_pop.reshape(-1, *new_pop.shape[2:])

0 commit comments

Comments
 (0)