Skip to content

Commit 6b5a26f

Browse files
committed
switch PRNGKey to key
1 parent 2d624f8 commit 6b5a26f

File tree

4 files changed

+15
-15
lines changed

4 files changed

+15
-15
lines changed

chromax/functional.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def cross(
2525
:param recombination_vec: array of m probabilities.
2626
The i-th value represent the probability to recombine before the marker i.
2727
:type recombination_vec:
28-
:param random_key: JAX Array with dtype jax.dtypes.prng_key, for reproducibility purpose.
28+
:param random_key: JAX random key, for reproducibility purpose.
2929
:type random_key: jax.Array
3030
:param mutation_probability: The probability of having a mutation in a marker.
3131
:type mutation_probability: float
@@ -43,7 +43,7 @@ def cross(
4343
>>> rec_vec = np.full((n_chr, chr_len), 1.5 / chr_len)
4444
>>> rec_vec[:, 0] = 0.5 # equal probability on starting haploid
4545
>>> rec_vec = rec_vec.flatten()
46-
>>> random_key = jax.random.PRNGKey(42)
46+
>>> random_key = jax.random.key(42)
4747
>>> f2 = functional.cross(parents, rec_vec, random_key)
4848
>>> f2.shape
4949
(50, 1000, 2)
@@ -52,7 +52,7 @@ def cross(
5252
random_keys = jax.random.split(
5353
random_key, num=2 * len(parents) * 2 * parents.shape[3]
5454
)
55-
random_keys = random_keys.reshape(2, len(parents), 2, parents.shape[3], 2)
55+
random_keys = random_keys.reshape(2, len(parents), 2, parents.shape[3])
5656
cross_random_key, mutate_random_key = random_keys
5757

5858
offsprings = _cross(
@@ -100,8 +100,8 @@ def double_haploid(
100100
:param recombination_vec: array of m probabilities.
101101
The i-th value represent the probability to recombine before the marker i.
102102
:type recombination_vec: ndarray
103-
:param random_key: array of n PRNGKey, one for each individual.
104-
:type random_key: jax.Array with dtype jax.dtypes.prng_key
103+
:param random_key: JAX random key, for reproducibility purpose.
104+
:type random_key: jax.Array
105105
:param mutation_probability: The probability of having a mutation in a marker.
106106
:type mutation_probability: float
107107
:return: output population of shape (n, n_offspring, m, d).
@@ -118,15 +118,15 @@ def double_haploid(
118118
>>> rec_vec = np.full((n_chr, chr_len), 1.5 / chr_len)
119119
>>> rec_vec[:, 0] = 0.5 # equal probability on starting haploid
120120
>>> rec_vec = rec_vec.flatten()
121-
>>> random_key = jax.random.PRNGKey(42)
121+
>>> random_key = jax.random.key(42)
122122
>>> dh = functional.double_haploid(f1, 10, rec_vec, random_key)
123123
>>> dh.shape
124124
(50, 10, 1000, 2)
125125
"""
126126
population = population.reshape(*population.shape[:2], -1, 2)
127127
keys = jax.random.split(
128128
random_key, num=2 * len(population) * n_offspring * population.shape[2]
129-
).reshape(2, len(population), n_offspring, population.shape[2], 2)
129+
).reshape(2, len(population), n_offspring, population.shape[2])
130130
cross_random_key, mutate_random_key = keys
131131
haploids = _double_haploid(
132132
population,

chromax/simulator.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def set_seed(self, seed: int):
171171
:param seed: random seed.
172172
:type seed: int
173173
"""
174-
self.random_key = jax.random.PRNGKey(seed)
174+
self.random_key = jax.random.key(seed)
175175

176176
def load_population(self, file_name: Union[Path, str]) -> Population["n"]:
177177
"""Load a population from file.
@@ -271,7 +271,7 @@ def differentiable_cross_func(self) -> Callable:
271271
>>> f1 = simulator.load_population(sample_data.genome)
272272
>>> weights = np.random.uniform(size=(10, len(f1), 2))
273273
>>> weights /= weights.sum(axis=1, keepdims=True)
274-
>>> random_key = jax.random.PRNGKey(42)
274+
>>> random_key = jax.random.key(42)
275275
>>> grad_value = grad_f(f1, weights, random_key)
276276
>>> grad_value.shape
277277
(10, 371, 2)
@@ -284,7 +284,7 @@ def differentiable_cross_func(self) -> Callable:
284284
def diff_cross_f(
285285
population: Population["n"],
286286
cross_weights: Float[Array, "m n 2"],
287-
random_key: jax.random.PRNGKeyArray,
287+
random_key: jax.Array,
288288
) -> Population["m"]:
289289
population = population.reshape(*population.shape[:-1], -1, 2)
290290
keys_shape = len(cross_weights), len(population), 2, population.shape[-2]

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ classifiers = [
2121
"Programming Language :: Python :: 3.11",
2222
'Intended Audience :: Science/Research',
2323
]
24-
dependencies = ["numpy", "pandas", "jax", "jaxlib", "jaxtyping"]
24+
dependencies = ["numpy", "pandas", "jax>=0.4.16", "jaxlib>=0.4.16", "jaxtyping"]
2525
dynamic = ["version"]
2626

2727
[project.optional-dependencies]

tests/test_functional.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def test_cross(idx):
1515
parents = np.random.choice([False, True], size=parents_shape)
1616
rec_vec = np.zeros(n_markers)
1717
rec_vec[0] = idx
18-
random_key = jax.random.PRNGKey(42)
18+
random_key = jax.random.key(42)
1919
new_pop = functional.cross(parents, rec_vec, random_key)
2020

2121
for i in range(ploidy):
@@ -28,7 +28,7 @@ def test_double_haploid():
2828
pop_shape = (50, n_chr * chr_len, ploidy)
2929
f1 = np.random.choice([False, True], size=pop_shape)
3030
rec_vec = np.full((n_chr * chr_len,), 1.5 / chr_len)
31-
random_key = jax.random.PRNGKey(42)
31+
random_key = jax.random.key(42)
3232
dh = functional.double_haploid(f1, n_offspring, rec_vec, random_key)
3333
assert dh.shape == (len(f1), n_offspring, n_chr * chr_len, ploidy)
3434

@@ -62,7 +62,7 @@ def test_cross_mutation():
6262
rec_vec = np.full((n_markers,), 1.5e-2)
6363
cross = functional.cross
6464

65-
random_key = jax.random.PRNGKey(42)
65+
random_key = jax.random.key(42)
6666
assert np.all(cross(zeros_pop, rec_vec, random_key) == 0)
6767
assert np.all(cross(zeros_pop, rec_vec, random_key, 1) == 1)
6868
mutated_pop = cross(zeros_pop, rec_vec, random_key, 0.5)
@@ -83,7 +83,7 @@ def test_dh_mutation():
8383
rec_vec = np.full((n_markers,), 1.5e-2)
8484
dh = functional.double_haploid
8585

86-
random_key = jax.random.PRNGKey(42)
86+
random_key = jax.random.key(42)
8787
assert np.all(dh(zeros_pop, 10, rec_vec, random_key) == 0)
8888
assert np.all(dh(zeros_pop, 10, rec_vec, random_key, 1) == 1)
8989
mutated_pop = dh(zeros_pop, 10, rec_vec, random_key, 0.5)

0 commit comments

Comments
 (0)