You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardexpand all lines: docs/tutorials/distributed_computation.rst
+4-10
Original file line number
Diff line number
Diff line change
@@ -3,13 +3,8 @@ Distributed computation
3
3
4
4
We present here how to perform computation on multiple devices.
5
5
6
-
Imagine to have at your disposal 4 GPUs and you want to distribute the workload on them.
7
-
There are two ways to do so:
8
-
9
-
* Create 4 simulators, specifying a different device for each one
10
-
* Use the `JAX pmap <https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html>`_ function to wrap the functions you need.
11
-
12
-
If memory is not an issue, the second method is the easiest one. In fact, you simply need divide your population in groups (i.e. divide the first axis) and distribute over the groups.
6
+
Consider a scenario where you have access to four GPUs and aim to distribute the workload effectively among them.
7
+
To achieve this, we employ the `JAX pmap <https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html>`_ function, which allows seamless distribution of functions across multiple accelerators.
13
8
14
9
.. code-block:: python
15
10
@@ -21,7 +16,7 @@ If memory is not an issue, the second method is the easiest one. In fact, you si
21
16
# load 200 individuals
22
17
population = simulator.load_population(genome)[:200]
23
18
# divide them in 4 groups
24
-
population = population.reshape(4, 50, *population.shape[1:])
19
+
population = population.reshape(4, -1, *population.shape[1:])
25
20
26
21
# prepare a parallelized function over groups
27
22
pmap_dh = jax.pmap(
@@ -35,7 +30,6 @@ If memory is not an issue, the second method is the easiest one. In fact, you si
35
30
dh_pop = dh_pop.reshape(-1, *dh_pop.shape[2:])
36
31
37
32
38
-
39
33
If you want to perform random crosses or full diallel, grouping the population will change the semantics (the random crosses or the full diallel will be performed by group independently).
40
34
In this case, you should use the function ``cross`` after generating the proper array of parents.
41
35
For example, to perform random crosses:
@@ -51,7 +45,7 @@ For example, to perform random crosses:
0 commit comments