@@ -25,7 +25,7 @@ def cross(
25
25
:param recombination_vec: array of m probabilities.
26
26
The i-th value represent the probability to recombine before the marker i.
27
27
:type recombination_vec:
28
- :param random_key: JAX Array with dtype jax.dtypes.prng_key , for reproducibility purpose.
28
+ :param random_key: JAX random key , for reproducibility purpose.
29
29
:type random_key: jax.Array
30
30
:param mutation_probability: The probability of having a mutation in a marker.
31
31
:type mutation_probability: float
@@ -43,7 +43,7 @@ def cross(
43
43
>>> rec_vec = np.full((n_chr, chr_len), 1.5 / chr_len)
44
44
>>> rec_vec[:, 0] = 0.5 # equal probability on starting haploid
45
45
>>> rec_vec = rec_vec.flatten()
46
- >>> random_key = jax.random.PRNGKey (42)
46
+ >>> random_key = jax.random.key (42)
47
47
>>> f2 = functional.cross(parents, rec_vec, random_key)
48
48
>>> f2.shape
49
49
(50, 1000, 2)
@@ -52,7 +52,7 @@ def cross(
52
52
random_keys = jax .random .split (
53
53
random_key , num = 2 * len (parents ) * 2 * parents .shape [3 ]
54
54
)
55
- random_keys = random_keys .reshape (2 , len (parents ), 2 , parents .shape [3 ], 2 )
55
+ random_keys = random_keys .reshape (2 , len (parents ), 2 , parents .shape [3 ])
56
56
cross_random_key , mutate_random_key = random_keys
57
57
58
58
offsprings = _cross (
@@ -100,8 +100,8 @@ def double_haploid(
100
100
:param recombination_vec: array of m probabilities.
101
101
The i-th value represent the probability to recombine before the marker i.
102
102
:type recombination_vec: ndarray
103
- :param random_key: array of n PRNGKey, one for each individual .
104
- :type random_key: jax.Array with dtype jax.dtypes.prng_key
103
+ :param random_key: JAX random key, for reproducibility purpose .
104
+ :type random_key: jax.Array
105
105
:param mutation_probability: The probability of having a mutation in a marker.
106
106
:type mutation_probability: float
107
107
:return: output population of shape (n, n_offspring, m, d).
@@ -118,15 +118,15 @@ def double_haploid(
118
118
>>> rec_vec = np.full((n_chr, chr_len), 1.5 / chr_len)
119
119
>>> rec_vec[:, 0] = 0.5 # equal probability on starting haploid
120
120
>>> rec_vec = rec_vec.flatten()
121
- >>> random_key = jax.random.PRNGKey (42)
121
+ >>> random_key = jax.random.key (42)
122
122
>>> dh = functional.double_haploid(f1, 10, rec_vec, random_key)
123
123
>>> dh.shape
124
124
(50, 10, 1000, 2)
125
125
"""
126
126
population = population .reshape (* population .shape [:2 ], - 1 , 2 )
127
127
keys = jax .random .split (
128
128
random_key , num = 2 * len (population ) * n_offspring * population .shape [2 ]
129
- ).reshape (2 , len (population ), n_offspring , population .shape [2 ], 2 )
129
+ ).reshape (2 , len (population ), n_offspring , population .shape [2 ])
130
130
cross_random_key , mutate_random_key = keys
131
131
haploids = _double_haploid (
132
132
population ,
0 commit comments