diff --git a/blinx/fluorescence_model.py b/blinx/fluorescence_model.py index e3011d8..6ba6b15 100644 --- a/blinx/fluorescence_model.py +++ b/blinx/fluorescence_model.py @@ -73,11 +73,11 @@ def p_x_given_z( ) -def p_norm(x_tilda_left, x_tilda_right, loc, scale): +def p_norm(x_left, x_right, loc, scale): # implimnetation of the normal distribution - cdf_left = jax.scipy.stats.norm.cdf(x_tilda_left, loc=loc, scale=scale) - cdf_right = jax.scipy.stats.norm.cdf(x_tilda_right, loc=loc, scale=scale) + cdf_left = jax.scipy.stats.norm.cdf(x_left, loc=loc, scale=scale) + cdf_right = jax.scipy.stats.norm.cdf(x_right, loc=loc, scale=scale) return cdf_right - cdf_left diff --git a/blinx/hyper_parameters.py b/blinx/hyper_parameters.py index b0e99c6..147e17c 100644 --- a/blinx/hyper_parameters.py +++ b/blinx/hyper_parameters.py @@ -1,4 +1,5 @@ from .parameters import Parameters +import jax.numpy as jnp def create_step_sizes(*args, **kwargs): @@ -84,6 +85,7 @@ def __init__( num_x_bins=1024, p_outlier=0.1, delta_t=200.0, + param_min_max_scale=5, # how many sigmas away from mean should the model consider r_e_loc=None, r_e_scale=None, r_bg_loc=None, @@ -106,6 +108,9 @@ def __init__( self.num_x_bins = num_x_bins self.p_outlier = p_outlier self.delta_t = delta_t + + # priors + self.param_min_max_scale = param_min_max_scale self.r_e_loc = r_e_loc self.r_e_scale = r_e_scale self.r_bg_loc = r_bg_loc @@ -127,3 +132,15 @@ def __init__( raise RuntimeError("Both mu_loc and mu_scale need to be provided") if sum([sigma_loc is None, sigma_scale is None]) == 1: raise RuntimeError("Both sigma_loc and sigma_scale need to be provided") + + # need to define bin sizes for norm.cdf to get actual probabilities from priors + if r_e_loc is not None: + self.r_e_bin = r_e_loc / num_x_bins + if r_bg_loc is not None: + self.r_bg_bin = r_bg_loc / num_x_bins + if g_loc is not None: + self.g_bin = g_loc / num_x_bins + if mu_loc is not None: + self.mu_bin = mu_loc / num_x_bins + if sigma_loc is not None: + self.sigma_bin = sigma_loc / num_x_bins \ No newline at end of file diff --git a/blinx/trace_model.py b/blinx/trace_model.py index 348bb60..8093a26 100644 --- a/blinx/trace_model.py +++ b/blinx/trace_model.py @@ -21,33 +21,45 @@ def log_p_parameters(parameters, hyper_parameters): log_p = 0.0 if hyper_parameters.r_e_loc is not None: log_p += jnp.log( - norm.pdf( - parameters.r_e, hyper_parameters.r_e_loc, hyper_parameters.r_e_scale - ) + p_norm( + parameters.r_e - hyper_parameters.r_e_bin/2, + parameters.r_e + hyper_parameters.r_e_bin/2, + hyper_parameters.r_e_loc, + hyper_parameters.r_e_scale) ) if hyper_parameters.r_bg_loc is not None: log_p += jnp.log( - norm.pdf( - parameters.r_bg, hyper_parameters.r_bg_loc, hyper_parameters.r_bg_scale - ) + p_norm( + parameters.r_bg - hyper_parameters.r_bg_bin/2, + parameters.r_bg + hyper_parameters.r_bg_bin/2, + hyper_parameters.r_bg_loc, + hyper_parameters.r_bg_scale) ) if hyper_parameters.g_loc is not None: log_p += jnp.log( - norm.pdf(parameters.gain, hyper_parameters.g_loc, hyper_parameters.g_scale) + p_norm( + parameters.gain - hyper_parameters.g_bin/2, + parameters.gain + hyper_parameters.g_bin/2, + hyper_parameters.g_loc, + hyper_parameters.g_scale) ) if hyper_parameters.mu_loc is not None: log_p += jnp.log( - norm.pdf(parameters.mu_ro, hyper_parameters.mu_loc, hyper_parameters.mu_scale) + p_norm( + parameters.mu_ro - hyper_parameters.mu_bin/2, + parameters.mu_ro + hyper_parameters.mu_bin/2, + hyper_parameters.mu_loc, + hyper_parameters.mu_scale) ) - if hyper_parameters.sigma_loc is not None: log_p += jnp.log( - norm.pdf(parameters.sigma_ro, hyper_parameters.sigma_loc, hyper_parameters.sigma_scale) + p_norm( + parameters.sigma_ro - hyper_parameters.sigma_bin/2, + parameters.sigma_ro + hyper_parameters.sigma_bin/2, + hyper_parameters.sigma_loc, + hyper_parameters.sigma_scale) ) - # sigma is a uniform prior distribution and will add a constant to all models --> we leave it out - # log_p_sigma = jnp.log(1.0 / (hyper_parameters.sigma_max - hyper_parameters.sigma_min)) - # We don't model a uniform prior distribution for p_on and p_off, because with bounds 0-1 it reduces to 0 return log_p