Skip to content

Commit 7d24a2e

Browse files
jeffhsu3younik
andauthored
fixes PRNGKey deprecation (#15)
* fixes #14 PRNGKey Array dtype * switch PRNGKey to key * drop 3.8 support * fix lint --------- Co-authored-by: Omar Younis <omar.younis98@gmail.com>
1 parent 450ca24 commit 7d24a2e

File tree

5 files changed

+26
-27
lines changed

5 files changed

+26
-27
lines changed

.github/workflows/build.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ jobs:
1515
runs-on: ubuntu-latest
1616
strategy:
1717
matrix:
18-
python-version: ['3.8', '3.9', '3.10', '3.11']
18+
python-version: ['3.9', '3.10', '3.11', '3.12']
1919
steps:
2020
- uses: actions/checkout@v3
2121
- name: Set up Python ${{ matrix.python-version }}

chromax/functional.py

+16-16
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
def cross(
1414
parents: Parents["n"],
1515
recombination_vec: Float[Array, N_MARKERS],
16-
random_key: jax.random.PRNGKeyArray,
16+
random_key: jax.Array,
1717
mutation_probability: float = 0.0,
1818
) -> Population["n"]:
1919
"""Main function that computes crosses from a list of parents.
@@ -25,8 +25,8 @@ def cross(
2525
:param recombination_vec: array of m probabilities.
2626
The i-th value represent the probability to recombine before the marker i.
2727
:type recombination_vec:
28-
:param random_key: JAX PRNGKey, for reproducibility purpose.
29-
:type random_key: jax.random.PRNGKeyArray
28+
:param random_key: JAX random key, for reproducibility purpose.
29+
:type random_key: jax.Array
3030
:param mutation_probability: The probability of having a mutation in a marker.
3131
:type mutation_probability: float
3232
:return: offspring population of shape (n, m, d).
@@ -43,7 +43,7 @@ def cross(
4343
>>> rec_vec = np.full((n_chr, chr_len), 1.5 / chr_len)
4444
>>> rec_vec[:, 0] = 0.5 # equal probability on starting haploid
4545
>>> rec_vec = rec_vec.flatten()
46-
>>> random_key = jax.random.PRNGKey(42)
46+
>>> random_key = jax.random.key(42)
4747
>>> f2 = functional.cross(parents, rec_vec, random_key)
4848
>>> f2.shape
4949
(50, 1000, 2)
@@ -52,7 +52,7 @@ def cross(
5252
random_keys = jax.random.split(
5353
random_key, num=2 * len(parents) * 2 * parents.shape[3]
5454
)
55-
random_keys = random_keys.reshape(2, len(parents), 2, parents.shape[3], 2)
55+
random_keys = random_keys.reshape(2, len(parents), 2, parents.shape[3])
5656
cross_random_key, mutate_random_key = random_keys
5757

5858
offsprings = _cross(
@@ -71,8 +71,8 @@ def cross(
7171
def _cross(
7272
parent: Individual,
7373
recombination_vec: Float[Array, N_MARKERS],
74-
cross_random_key: jax.random.PRNGKeyArray,
75-
mutate_random_key: jax.random.PRNGKeyArray,
74+
cross_random_key: jax.Array,
75+
mutate_random_key: jax.Array,
7676
mutation_probability: float,
7777
) -> Haploid:
7878
return _meiosis(
@@ -88,7 +88,7 @@ def double_haploid(
8888
population: Population["n"],
8989
n_offspring: int,
9090
recombination_vec: Float[Array, N_MARKERS],
91-
random_key: jax.random.PRNGKeyArray,
91+
random_key: jax.Array,
9292
mutation_probability: float = 0.0,
9393
) -> Population["n n_offspring"]:
9494
"""Computes the double haploid of the input population.
@@ -100,8 +100,8 @@ def double_haploid(
100100
:param recombination_vec: array of m probabilities.
101101
The i-th value represent the probability to recombine before the marker i.
102102
:type recombination_vec: ndarray
103-
:param random_key: array of n PRNGKey, one for each individual.
104-
:type random_key: jax.random.PRNGKeyArray
103+
:param random_key: JAX random key, for reproducibility purpose.
104+
:type random_key: jax.Array
105105
:param mutation_probability: The probability of having a mutation in a marker.
106106
:type mutation_probability: float
107107
:return: output population of shape (n, n_offspring, m, d).
@@ -118,15 +118,15 @@ def double_haploid(
118118
>>> rec_vec = np.full((n_chr, chr_len), 1.5 / chr_len)
119119
>>> rec_vec[:, 0] = 0.5 # equal probability on starting haploid
120120
>>> rec_vec = rec_vec.flatten()
121-
>>> random_key = jax.random.PRNGKey(42)
121+
>>> random_key = jax.random.key(42)
122122
>>> dh = functional.double_haploid(f1, 10, rec_vec, random_key)
123123
>>> dh.shape
124124
(50, 10, 1000, 2)
125125
"""
126126
population = population.reshape(*population.shape[:2], -1, 2)
127127
keys = jax.random.split(
128128
random_key, num=2 * len(population) * n_offspring * population.shape[2]
129-
).reshape(2, len(population), n_offspring, population.shape[2], 2)
129+
).reshape(2, len(population), n_offspring, population.shape[2])
130130
cross_random_key, mutate_random_key = keys
131131
haploids = _double_haploid(
132132
population,
@@ -145,8 +145,8 @@ def double_haploid(
145145
def _double_haploid(
146146
individual: Individual,
147147
recombination_vec: Float[Array, N_MARKERS],
148-
cross_random_key: jax.random.PRNGKeyArray,
149-
mutate_random_key: jax.random.PRNGKeyArray,
148+
cross_random_key: jax.Array,
149+
mutate_random_key: jax.Array,
150150
mutation_probability: float,
151151
) -> Haploid:
152152
return _meiosis(
@@ -165,8 +165,8 @@ def _double_haploid(
165165
def _meiosis(
166166
individual: Individual,
167167
recombination_vec: Float[Array, N_MARKERS],
168-
cross_random_key: jax.random.PRNGKeyArray,
169-
mutate_random_key: jax.random.PRNGKeyArray,
168+
cross_random_key: jax.Array,
169+
mutate_random_key: jax.Array,
170170
mutation_probability: float,
171171
) -> Haploid:
172172
samples = jax.random.uniform(cross_random_key, shape=recombination_vec.shape)

chromax/simulator.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def set_seed(self, seed: int):
171171
:param seed: random seed.
172172
:type seed: int
173173
"""
174-
self.random_key = jax.random.PRNGKey(seed)
174+
self.random_key = jax.random.key(seed)
175175

176176
def load_population(self, file_name: Union[Path, str]) -> Population["n"]:
177177
"""Load a population from file.
@@ -271,7 +271,7 @@ def differentiable_cross_func(self) -> Callable:
271271
>>> f1 = simulator.load_population(sample_data.genome)
272272
>>> weights = np.random.uniform(size=(10, len(f1), 2))
273273
>>> weights /= weights.sum(axis=1, keepdims=True)
274-
>>> random_key = jax.random.PRNGKey(42)
274+
>>> random_key = jax.random.key(42)
275275
>>> grad_value = grad_f(f1, weights, random_key)
276276
>>> grad_value.shape
277277
(10, 371, 2)
@@ -284,7 +284,7 @@ def differentiable_cross_func(self) -> Callable:
284284
def diff_cross_f(
285285
population: Population["n"],
286286
cross_weights: Float[Array, "m n 2"],
287-
random_key: jax.random.PRNGKeyArray,
287+
random_key: jax.Array,
288288
) -> Population["m"]:
289289
population = population.reshape(*population.shape[:-1], -1, 2)
290290
keys_shape = len(cross_weights), len(population), 2, population.shape[-2]

pyproject.toml

+2-3
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,19 @@ build-backend = "setuptools.build_meta"
88
name = "chromax"
99
description = "Breeding simulator based on JAX"
1010
readme = "README.md"
11-
requires-python = ">= 3.8"
11+
requires-python = ">= 3.9"
1212
authors = [{ name = "Omar G. Younis", email = "omar.younis98@gmail.com" }]
1313
license = { text = "BSD-3-Clause" }
1414
keywords = ["Breeding", "simulator", "JAX", "chromosome", "genetics", "bioinformatics",]
1515
classifiers = [
1616
"Development Status :: 4 - Beta",
1717
"Programming Language :: Python :: 3",
18-
"Programming Language :: Python :: 3.8",
1918
"Programming Language :: Python :: 3.9",
2019
"Programming Language :: Python :: 3.10",
2120
"Programming Language :: Python :: 3.11",
2221
'Intended Audience :: Science/Research',
2322
]
24-
dependencies = ["numpy", "pandas", "jax", "jaxlib", "jaxtyping"]
23+
dependencies = ["numpy", "pandas", "jax>=0.4.16", "jaxlib>=0.4.16", "jaxtyping"]
2524
dynamic = ["version"]
2625

2726
[project.optional-dependencies]

tests/test_functional.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def test_cross(idx):
1515
parents = np.random.choice([False, True], size=parents_shape)
1616
rec_vec = np.zeros(n_markers)
1717
rec_vec[0] = idx
18-
random_key = jax.random.PRNGKey(42)
18+
random_key = jax.random.key(42)
1919
new_pop = functional.cross(parents, rec_vec, random_key)
2020

2121
for i in range(ploidy):
@@ -28,7 +28,7 @@ def test_double_haploid():
2828
pop_shape = (50, n_chr * chr_len, ploidy)
2929
f1 = np.random.choice([False, True], size=pop_shape)
3030
rec_vec = np.full((n_chr * chr_len,), 1.5 / chr_len)
31-
random_key = jax.random.PRNGKey(42)
31+
random_key = jax.random.key(42)
3232
dh = functional.double_haploid(f1, n_offspring, rec_vec, random_key)
3333
assert dh.shape == (len(f1), n_offspring, n_chr * chr_len, ploidy)
3434

@@ -62,7 +62,7 @@ def test_cross_mutation():
6262
rec_vec = np.full((n_markers,), 1.5e-2)
6363
cross = functional.cross
6464

65-
random_key = jax.random.PRNGKey(42)
65+
random_key = jax.random.key(42)
6666
assert np.all(cross(zeros_pop, rec_vec, random_key) == 0)
6767
assert np.all(cross(zeros_pop, rec_vec, random_key, 1) == 1)
6868
mutated_pop = cross(zeros_pop, rec_vec, random_key, 0.5)
@@ -83,7 +83,7 @@ def test_dh_mutation():
8383
rec_vec = np.full((n_markers,), 1.5e-2)
8484
dh = functional.double_haploid
8585

86-
random_key = jax.random.PRNGKey(42)
86+
random_key = jax.random.key(42)
8787
assert np.all(dh(zeros_pop, 10, rec_vec, random_key) == 0)
8888
assert np.all(dh(zeros_pop, 10, rec_vec, random_key, 1) == 1)
8989
mutated_pop = dh(zeros_pop, 10, rec_vec, random_key, 0.5)

0 commit comments

Comments
 (0)