-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Mutation #13
Mutation #13
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Many thanks for doing this!
I left some comments. I have time now, so if you need any help just let me know!
chromax/simulator.py
Outdated
if mutation is None: | ||
self.mutation = 0.0 | ||
elif mutation > 0 and mutation < 1: | ||
self.mutation = mutation | ||
else: | ||
raise ValueError( | ||
f"mutation must be between 0 and 1, but got {mutation}" | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To simplify this, you can use as default 0 in init definition mutation: float = 0f
Then here, something like this:
if mutation < 0 or mutation > 0:
raise ValueError(
f"mutation must be between 0 and 1, but got {mutation}"
)
self.mutation = mutation
chromax/simulator.py
Outdated
print(f'shape is {self.random_key.shape} {self.random_key[0]}') | ||
return |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
to remove
chromax/functional.py
Outdated
rec_sites = samples < recombination_vec | ||
crossover_mask = jax.lax.associative_scan(jnp.logical_xor, rec_sites) | ||
|
||
crossover_mask = crossover_mask.astype(jnp.int8) | ||
haploid = jnp.take_along_axis(individual, crossover_mask[:, None], axis=-1) | ||
|
||
mutation_samples = jax.random.uniform(mutate_random_key, shape=haploid.shape) | ||
mutation_sites = mutation_samples > mutate_probability |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should be mutation_samples < mutate_probability
, isn't it?
If mutate_probability = 0
, mutation_sites
should always be False
chromax/functional.py
Outdated
random_keys = jax.random.split(random_key, num=len(parents) * 2 * parents.shape[3]) | ||
random_keys = random_keys.reshape(len(parents), 2, parents.shape[3], 2) | ||
offsprings = _cross(parents, recombination_vec, random_keys) | ||
cross_random_key = jax.random.split(cross_random_key, num=len(parents) * 2 * parents.shape[3]) | ||
cross_random_key = cross_random_key.reshape(len(parents), 2, parents.shape[3], 2) | ||
|
||
mutate_split_key = jax.random.split(mutate_split_key, num=len(parents) * 2 * parents.shape[3]) | ||
mutate_split_key = mutate_split_key.reshape(len(parents), 2, parents.shape[3], 2) | ||
|
||
offsprings = _cross(parents, recombination_vec, cross_random_key, mutate_split_key, mutate_probability) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would avoid adding the other keys as argument, but I would generate them internally here.
So you can do:
random_keys = jax.random.split(random_key, num=2 * len(parents) * 2 * parents.shape[3])
random_keys = random_keys.reshape(3, len(parents), 2, parents.shape[3], 2)
cross_random_key, mutate_random_key = random_keys
Also, use default value of 0 for mutate_probability, so previous code will continue working
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
random_keys = random_keys.reshape(3, len(parents), 2, parents.shape[3], 2)
you mean
random_keys = random_keys.reshape(2, len(parents), 2, parents.shape[3], 2)?
chromax/functional.py
Outdated
cross_random_key: jax.random.PRNGKeyArray, | ||
mutate_random_key: jax.random.PRNGKeyArray, | ||
mutate_probability: float, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same comment as cross
have a look new change |
Merged, thank you! @oddoking |
add mutation random key and mutate prob