From 74ff46ed6e8582cf3fb8054f2a8f4d1178f9f77b Mon Sep 17 00:00:00 2001 From: Yogesh Thambidurai Date: Thu, 20 Jun 2024 11:34:11 -0400 Subject: [PATCH] add vmapped version of generate_trace --- blinx/trace_model.py | 53 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 51 insertions(+), 2 deletions(-) diff --git a/blinx/trace_model.py b/blinx/trace_model.py index dddeeb4..ccc60a9 100644 --- a/blinx/trace_model.py +++ b/blinx/trace_model.py @@ -30,9 +30,13 @@ def log_p_parameters(parameters, locs, scales): if locs.sigma_ro is not None: log_p += jnp.log(norm.pdf(parameters.sigma_ro, locs.sigma_ro, scales.sigma_ro)) if locs._p_on_logit is not None: - log_p += jnp.log(norm.pdf(parameters.p_on, locs._p_on_logit, scales._p_on_logit)) + log_p += jnp.log( + norm.pdf(parameters.p_on, locs._p_on_logit, scales._p_on_logit) + ) if locs._p_off_logit is not None: - log_p += jnp.log(norm.pdf(parameters.p_off, locs._p_off_logit, scales._p_off_logit)) + log_p += jnp.log( + norm.pdf(parameters.p_off, locs._p_off_logit, scales._p_off_logit) + ) return log_p @@ -281,6 +285,51 @@ def sample_next_z(z, p_transition, key): return z +def vmap_generate_trace( + num_traces, y, parameters, num_frames, hyper_parameters, seed=None +): + """Create several simulated intensity traces. + + Args: + num_traces (int): + - the number of traces to simulate + + y (int): + - the total number of fluorescent emitters + + parameters (:class:'Parameters'): + - the parameters of the fluoresent and trace model + + num_frames (int): + - the number of observations to simulate + + hyper_parameters (:class:`HyperParameters`): + - hypxer-parameters with `delta_t` set for the time between frames in the traces + + seed (int, optional): + - random seed for the jax psudo rendom number generator + + Returns: + trace (array): + - a num_traces x num_frames array containing traces with intensity values for each frame + + states (array): + - array the same shape as trace, containing the number of 'on' emitters in each frame + """ + + if seed is None: + seed = time.time_ns() + key = random.PRNGKey(seed) + subkeys = random.split(key, num_traces) + seeds = subkeys[:, 0] + mapped = jax.vmap( + generate_trace, + in_axes=(None, None, None, None, 0), + ) + trace, zs = mapped(y, parameters, num_frames, hyper_parameters, seeds) + return jnp.squeeze(trace), jnp.squeeze(zs) + + def create_transition_matrix(y, p_on, p_off): """Create a transition matrix for the number of active elements, given that elements can randomly turn on and off.