Skip to content

Commit 022882c

Browse files
committed
fix pre-commit
1 parent 4922b71 commit 022882c

File tree

2 files changed

+42
-9
lines changed

2 files changed

+42
-9
lines changed

chromax/functional.py

+34-7
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,19 @@ def cross(
4949
(50, 1000, 2)
5050
"""
5151
parents = parents.reshape(*parents.shape[:3], -1, 2)
52-
random_keys = jax.random.split(random_key, num=2 * len(parents) * 2 * parents.shape[3])
52+
random_keys = jax.random.split(
53+
random_key, num=2 * len(parents) * 2 * parents.shape[3]
54+
)
5355
random_keys = random_keys.reshape(2, len(parents), 2, parents.shape[3], 2)
5456
cross_random_key, mutate_random_key = random_keys
5557

56-
offsprings = _cross(parents, recombination_vec, cross_random_key, mutate_random_key, mutation_probability)
58+
offsprings = _cross(
59+
parents,
60+
recombination_vec,
61+
cross_random_key,
62+
mutate_random_key,
63+
mutation_probability,
64+
)
5765
return offsprings.reshape(*offsprings.shape[:-2], -1)
5866

5967

@@ -67,7 +75,13 @@ def _cross(
6775
mutate_random_key: jax.random.PRNGKeyArray,
6876
mutation_probability: float,
6977
) -> Haploid:
70-
return _meiosis(parent, recombination_vec, cross_random_key, mutate_random_key, mutation_probability)
78+
return _meiosis(
79+
parent,
80+
recombination_vec,
81+
cross_random_key,
82+
mutate_random_key,
83+
mutation_probability,
84+
)
7185

7286

7387
def double_haploid(
@@ -109,13 +123,18 @@ def double_haploid(
109123
>>> dh.shape
110124
(50, 10, 1000, 2)
111125
"""
112-
113126
population = population.reshape(*population.shape[:2], -1, 2)
114127
keys = jax.random.split(
115128
random_key, num=2 * len(population) * n_offspring * population.shape[2]
116129
).reshape(2, len(population), n_offspring, population.shape[2], 2)
117130
cross_random_key, mutate_random_key = keys
118-
haploids = _double_haploid(population, recombination_vec, cross_random_key, mutate_random_key, mutation_probability)
131+
haploids = _double_haploid(
132+
population,
133+
recombination_vec,
134+
cross_random_key,
135+
mutate_random_key,
136+
mutation_probability,
137+
)
119138
dh_pop = jnp.broadcast_to(haploids[..., None], shape=(*haploids.shape, 2))
120139
return dh_pop.reshape(*dh_pop.shape[:-2], -1)
121140

@@ -130,11 +149,19 @@ def _double_haploid(
130149
mutate_random_key: jax.random.PRNGKeyArray,
131150
mutation_probability: float,
132151
) -> Haploid:
133-
return _meiosis(individual, recombination_vec, cross_random_key, mutate_random_key, mutation_probability)
152+
return _meiosis(
153+
individual,
154+
recombination_vec,
155+
cross_random_key,
156+
mutate_random_key,
157+
mutation_probability,
158+
)
134159

135160

136161
@jax.jit
137-
@partial(jax.vmap, in_axes=(1, None, 0, 0, None), out_axes=1) # parallelize pair of chromosomes
162+
@partial(
163+
jax.vmap, in_axes=(1, None, 0, 0, None), out_axes=1
164+
) # parallelize pair of chromosomes
138165
def _meiosis(
139166
individual: Individual,
140167
recombination_vec: Float[Array, N_MARKERS],

chromax/simulator.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,9 @@ def cross(self, parents: Parents["n"]) -> Population["n"]:
239239
(3, 9839, 2)
240240
"""
241241
self.random_key, split_key = jax.random.split(self.random_key)
242-
return functional.cross(parents, self.recombination_vec, split_key, self.mutation_probability)
242+
return functional.cross(
243+
parents, self.recombination_vec, split_key, self.mutation_probability
244+
)
243245

244246
@property
245247
def differentiable_cross_func(self) -> Callable:
@@ -319,7 +321,11 @@ def double_haploid(
319321
"""
320322
self.random_key, split_key = jax.random.split(self.random_key)
321323
dh = functional.double_haploid(
322-
population, n_offspring, self.recombination_vec, split_key, self.mutation_probability
324+
population,
325+
n_offspring,
326+
self.recombination_vec,
327+
split_key,
328+
self.mutation_probability,
323329
)
324330

325331
if n_offspring == 1:

0 commit comments

Comments
 (0)