Skip to content

add vmapped version of generate_trace #6

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 20, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 51 additions & 2 deletions blinx/trace_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,13 @@ def log_p_parameters(parameters, locs, scales):
if locs.sigma_ro is not None:
log_p += jnp.log(norm.pdf(parameters.sigma_ro, locs.sigma_ro, scales.sigma_ro))
if locs._p_on_logit is not None:
log_p += jnp.log(norm.pdf(parameters.p_on, locs._p_on_logit, scales._p_on_logit))
log_p += jnp.log(
norm.pdf(parameters.p_on, locs._p_on_logit, scales._p_on_logit)
)
if locs._p_off_logit is not None:
log_p += jnp.log(norm.pdf(parameters.p_off, locs._p_off_logit, scales._p_off_logit))
log_p += jnp.log(
norm.pdf(parameters.p_off, locs._p_off_logit, scales._p_off_logit)
)

return log_p

Expand Down Expand Up @@ -281,6 +285,51 @@ def sample_next_z(z, p_transition, key):
return z


def vmap_generate_trace(
num_traces, y, parameters, num_frames, hyper_parameters, seed=None
):
"""Create several simulated intensity traces.

Args:
num_traces (int):
- the number of traces to simulate

y (int):
- the total number of fluorescent emitters

parameters (:class:'Parameters'):
- the parameters of the fluoresent and trace model

num_frames (int):
- the number of observations to simulate

hyper_parameters (:class:`HyperParameters`):
- hypxer-parameters with `delta_t` set for the time between frames in the traces

seed (int, optional):
- random seed for the jax psudo rendom number generator

Returns:
trace (array):
- a num_traces x num_frames array containing traces with intensity values for each frame

states (array):
- array the same shape as trace, containing the number of 'on' emitters in each frame
"""

if seed is None:
seed = time.time_ns()
key = random.PRNGKey(seed)
subkeys = random.split(key, num_traces)
seeds = subkeys[:, 0]
mapped = jax.vmap(
generate_trace,
in_axes=(None, None, None, None, 0),
)
trace, zs = mapped(y, parameters, num_frames, hyper_parameters, seeds)
return jnp.squeeze(trace), jnp.squeeze(zs)


def create_transition_matrix(y, p_on, p_off):
"""Create a transition matrix for the number of active elements, given that
elements can randomly turn on and off.
Expand Down
Loading