13
13
def cross (
14
14
parents : Parents ["n" ],
15
15
recombination_vec : Float [Array , N_MARKERS ],
16
- random_key : jax .random . PRNGKeyArray ,
16
+ random_key : jax .Array ,
17
17
mutation_probability : float = 0.0 ,
18
18
) -> Population ["n" ]:
19
19
"""Main function that computes crosses from a list of parents.
@@ -25,8 +25,8 @@ def cross(
25
25
:param recombination_vec: array of m probabilities.
26
26
The i-th value represent the probability to recombine before the marker i.
27
27
:type recombination_vec:
28
- :param random_key: JAX PRNGKey , for reproducibility purpose.
29
- :type random_key: jax.random.PRNGKeyArray
28
+ :param random_key: JAX random key , for reproducibility purpose.
29
+ :type random_key: jax.Array
30
30
:param mutation_probability: The probability of having a mutation in a marker.
31
31
:type mutation_probability: float
32
32
:return: offspring population of shape (n, m, d).
@@ -43,7 +43,7 @@ def cross(
43
43
>>> rec_vec = np.full((n_chr, chr_len), 1.5 / chr_len)
44
44
>>> rec_vec[:, 0] = 0.5 # equal probability on starting haploid
45
45
>>> rec_vec = rec_vec.flatten()
46
- >>> random_key = jax.random.PRNGKey (42)
46
+ >>> random_key = jax.random.key (42)
47
47
>>> f2 = functional.cross(parents, rec_vec, random_key)
48
48
>>> f2.shape
49
49
(50, 1000, 2)
@@ -52,7 +52,7 @@ def cross(
52
52
random_keys = jax .random .split (
53
53
random_key , num = 2 * len (parents ) * 2 * parents .shape [3 ]
54
54
)
55
- random_keys = random_keys .reshape (2 , len (parents ), 2 , parents .shape [3 ], 2 )
55
+ random_keys = random_keys .reshape (2 , len (parents ), 2 , parents .shape [3 ])
56
56
cross_random_key , mutate_random_key = random_keys
57
57
58
58
offsprings = _cross (
@@ -71,8 +71,8 @@ def cross(
71
71
def _cross (
72
72
parent : Individual ,
73
73
recombination_vec : Float [Array , N_MARKERS ],
74
- cross_random_key : jax .random . PRNGKeyArray ,
75
- mutate_random_key : jax .random . PRNGKeyArray ,
74
+ cross_random_key : jax .Array ,
75
+ mutate_random_key : jax .Array ,
76
76
mutation_probability : float ,
77
77
) -> Haploid :
78
78
return _meiosis (
@@ -88,7 +88,7 @@ def double_haploid(
88
88
population : Population ["n" ],
89
89
n_offspring : int ,
90
90
recombination_vec : Float [Array , N_MARKERS ],
91
- random_key : jax .random . PRNGKeyArray ,
91
+ random_key : jax .Array ,
92
92
mutation_probability : float = 0.0 ,
93
93
) -> Population ["n n_offspring" ]:
94
94
"""Computes the double haploid of the input population.
@@ -100,8 +100,8 @@ def double_haploid(
100
100
:param recombination_vec: array of m probabilities.
101
101
The i-th value represent the probability to recombine before the marker i.
102
102
:type recombination_vec: ndarray
103
- :param random_key: array of n PRNGKey, one for each individual .
104
- :type random_key: jax.random.PRNGKeyArray
103
+ :param random_key: JAX random key, for reproducibility purpose .
104
+ :type random_key: jax.Array
105
105
:param mutation_probability: The probability of having a mutation in a marker.
106
106
:type mutation_probability: float
107
107
:return: output population of shape (n, n_offspring, m, d).
@@ -118,15 +118,15 @@ def double_haploid(
118
118
>>> rec_vec = np.full((n_chr, chr_len), 1.5 / chr_len)
119
119
>>> rec_vec[:, 0] = 0.5 # equal probability on starting haploid
120
120
>>> rec_vec = rec_vec.flatten()
121
- >>> random_key = jax.random.PRNGKey (42)
121
+ >>> random_key = jax.random.key (42)
122
122
>>> dh = functional.double_haploid(f1, 10, rec_vec, random_key)
123
123
>>> dh.shape
124
124
(50, 10, 1000, 2)
125
125
"""
126
126
population = population .reshape (* population .shape [:2 ], - 1 , 2 )
127
127
keys = jax .random .split (
128
128
random_key , num = 2 * len (population ) * n_offspring * population .shape [2 ]
129
- ).reshape (2 , len (population ), n_offspring , population .shape [2 ], 2 )
129
+ ).reshape (2 , len (population ), n_offspring , population .shape [2 ])
130
130
cross_random_key , mutate_random_key = keys
131
131
haploids = _double_haploid (
132
132
population ,
@@ -145,8 +145,8 @@ def double_haploid(
145
145
def _double_haploid (
146
146
individual : Individual ,
147
147
recombination_vec : Float [Array , N_MARKERS ],
148
- cross_random_key : jax .random . PRNGKeyArray ,
149
- mutate_random_key : jax .random . PRNGKeyArray ,
148
+ cross_random_key : jax .Array ,
149
+ mutate_random_key : jax .Array ,
150
150
mutation_probability : float ,
151
151
) -> Haploid :
152
152
return _meiosis (
@@ -165,8 +165,8 @@ def _double_haploid(
165
165
def _meiosis (
166
166
individual : Individual ,
167
167
recombination_vec : Float [Array , N_MARKERS ],
168
- cross_random_key : jax .random . PRNGKeyArray ,
169
- mutate_random_key : jax .random . PRNGKeyArray ,
168
+ cross_random_key : jax .Array ,
169
+ mutate_random_key : jax .Array ,
170
170
mutation_probability : float ,
171
171
) -> Haploid :
172
172
samples = jax .random .uniform (cross_random_key , shape = recombination_vec .shape )
0 commit comments