@@ -49,11 +49,19 @@ def cross(
49
49
(50, 1000, 2)
50
50
"""
51
51
parents = parents .reshape (* parents .shape [:3 ], - 1 , 2 )
52
- random_keys = jax .random .split (random_key , num = 2 * len (parents ) * 2 * parents .shape [3 ])
52
+ random_keys = jax .random .split (
53
+ random_key , num = 2 * len (parents ) * 2 * parents .shape [3 ]
54
+ )
53
55
random_keys = random_keys .reshape (2 , len (parents ), 2 , parents .shape [3 ], 2 )
54
56
cross_random_key , mutate_random_key = random_keys
55
57
56
- offsprings = _cross (parents , recombination_vec , cross_random_key , mutate_random_key , mutation_probability )
58
+ offsprings = _cross (
59
+ parents ,
60
+ recombination_vec ,
61
+ cross_random_key ,
62
+ mutate_random_key ,
63
+ mutation_probability ,
64
+ )
57
65
return offsprings .reshape (* offsprings .shape [:- 2 ], - 1 )
58
66
59
67
@@ -67,7 +75,13 @@ def _cross(
67
75
mutate_random_key : jax .random .PRNGKeyArray ,
68
76
mutation_probability : float ,
69
77
) -> Haploid :
70
- return _meiosis (parent , recombination_vec , cross_random_key , mutate_random_key , mutation_probability )
78
+ return _meiosis (
79
+ parent ,
80
+ recombination_vec ,
81
+ cross_random_key ,
82
+ mutate_random_key ,
83
+ mutation_probability ,
84
+ )
71
85
72
86
73
87
def double_haploid (
@@ -109,13 +123,18 @@ def double_haploid(
109
123
>>> dh.shape
110
124
(50, 10, 1000, 2)
111
125
"""
112
-
113
126
population = population .reshape (* population .shape [:2 ], - 1 , 2 )
114
127
keys = jax .random .split (
115
128
random_key , num = 2 * len (population ) * n_offspring * population .shape [2 ]
116
129
).reshape (2 , len (population ), n_offspring , population .shape [2 ], 2 )
117
130
cross_random_key , mutate_random_key = keys
118
- haploids = _double_haploid (population , recombination_vec , cross_random_key , mutate_random_key , mutation_probability )
131
+ haploids = _double_haploid (
132
+ population ,
133
+ recombination_vec ,
134
+ cross_random_key ,
135
+ mutate_random_key ,
136
+ mutation_probability ,
137
+ )
119
138
dh_pop = jnp .broadcast_to (haploids [..., None ], shape = (* haploids .shape , 2 ))
120
139
return dh_pop .reshape (* dh_pop .shape [:- 2 ], - 1 )
121
140
@@ -130,11 +149,19 @@ def _double_haploid(
130
149
mutate_random_key : jax .random .PRNGKeyArray ,
131
150
mutation_probability : float ,
132
151
) -> Haploid :
133
- return _meiosis (individual , recombination_vec , cross_random_key , mutate_random_key , mutation_probability )
152
+ return _meiosis (
153
+ individual ,
154
+ recombination_vec ,
155
+ cross_random_key ,
156
+ mutate_random_key ,
157
+ mutation_probability ,
158
+ )
134
159
135
160
136
161
@jax .jit
137
- @partial (jax .vmap , in_axes = (1 , None , 0 , 0 , None ), out_axes = 1 ) # parallelize pair of chromosomes
162
+ @partial (
163
+ jax .vmap , in_axes = (1 , None , 0 , 0 , None ), out_axes = 1
164
+ ) # parallelize pair of chromosomes
138
165
def _meiosis (
139
166
individual : Individual ,
140
167
recombination_vec : Float [Array , N_MARKERS ],
0 commit comments