Merge branch 'priors'
ahillsley committed Jan 22, 2024
2 parents 3f6b486 + 8c9dfdf commit cae843b
Showing 3 changed files with 45 additions and 16 deletions.
6 changes: 3 additions & 3 deletions blinx/fluorescence_model.py
@@ -73,11 +73,11 @@ def p_x_given_z(
     )
 
 
-def p_norm(x_tilda_left, x_tilda_right, loc, scale):
+def p_norm(x_left, x_right, loc, scale):
     # implementation of the normal distribution
 
-    cdf_left = jax.scipy.stats.norm.cdf(x_tilda_left, loc=loc, scale=scale)
-    cdf_right = jax.scipy.stats.norm.cdf(x_tilda_right, loc=loc, scale=scale)
+    cdf_left = jax.scipy.stats.norm.cdf(x_left, loc=loc, scale=scale)
+    cdf_right = jax.scipy.stats.norm.cdf(x_right, loc=loc, scale=scale)
 
     return cdf_right - cdf_left
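Note: p_norm returns the probability mass of a normal distribution on an interval by differencing the CDF at the two endpoints. A minimal sketch of calling it directly, restating the function from the diff above, with illustrative numbers not taken from the repository:

import jax.scipy.stats

def p_norm(x_left, x_right, loc, scale):
    # mass of Normal(loc, scale) on the interval [x_left, x_right]
    cdf_left = jax.scipy.stats.norm.cdf(x_left, loc=loc, scale=scale)
    cdf_right = jax.scipy.stats.norm.cdf(x_right, loc=loc, scale=scale)
    return cdf_right - cdf_left

# probability that a draw from Normal(5.0, 2.0) lands in the bin [4.9, 5.1]
mass = p_norm(4.9, 5.1, loc=5.0, scale=2.0)  # ~0.0399, roughly pdf(5.0) * bin width 0.2

For a narrow bin this mass is approximately pdf(center) * width, which is why the bin widths introduced in blinx/hyper_parameters.py below matter for turning densities into probabilities.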
17 changes: 17 additions & 0 deletions blinx/hyper_parameters.py
@@ -1,4 +1,5 @@
 from .parameters import Parameters
+import jax.numpy as jnp
 
 
 def create_step_sizes(*args, **kwargs):
@@ -84,6 +85,7 @@ def __init__(
         num_x_bins=1024,
         p_outlier=0.1,
         delta_t=200.0,
+        param_min_max_scale=5,  # how many sigmas away from the mean the model should consider
         r_e_loc=None,
         r_e_scale=None,
         r_bg_loc=None,
@@ -106,6 +108,9 @@ def __init__(
         self.num_x_bins = num_x_bins
         self.p_outlier = p_outlier
         self.delta_t = delta_t
+
+        # priors
+        self.param_min_max_scale = param_min_max_scale
         self.r_e_loc = r_e_loc
         self.r_e_scale = r_e_scale
         self.r_bg_loc = r_bg_loc
@@ -127,3 +132,15 @@ def __init__(
             raise RuntimeError("Both mu_loc and mu_scale need to be provided")
         if sum([sigma_loc is None, sigma_scale is None]) == 1:
             raise RuntimeError("Both sigma_loc and sigma_scale need to be provided")
+
+        # bin sizes are needed so that norm.cdf yields actual probabilities from the priors
+        if r_e_loc is not None:
+            self.r_e_bin = r_e_loc / num_x_bins
+        if r_bg_loc is not None:
+            self.r_bg_bin = r_bg_loc / num_x_bins
+        if g_loc is not None:
+            self.g_bin = g_loc / num_x_bins
+        if mu_loc is not None:
+            self.mu_bin = mu_loc / num_x_bins
+        if sigma_loc is not None:
+            self.sigma_bin = sigma_loc / num_x_bins
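Note: each bin width is the prior's loc divided by num_x_bins, so a prior centred at a larger value gets a proportionally wider bin. A hedged sketch of how such a width converts a continuous prior into a per-bin probability mass (the numbers are illustrative; p_norm is restated from blinx/fluorescence_model.py):

import jax.scipy.stats

def p_norm(x_left, x_right, loc, scale):
    # restated from blinx/fluorescence_model.py
    return (jax.scipy.stats.norm.cdf(x_right, loc=loc, scale=scale)
            - jax.scipy.stats.norm.cdf(x_left, loc=loc, scale=scale))

# illustrative hyper-parameter values, not from the repository
r_e_loc, r_e_scale, num_x_bins = 3.0, 0.5, 1024

# bin width as defined in __init__ above
r_e_bin = r_e_loc / num_x_bins  # ~0.0029

# prior probability mass assigned to a candidate r_e value
r_e = 2.8
prior_mass = p_norm(r_e - r_e_bin / 2, r_e + r_e_bin / 2, r_e_loc, r_e_scale)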
38 changes: 25 additions & 13 deletions blinx/trace_model.py
@@ -21,33 +21,45 @@ def log_p_parameters(parameters, hyper_parameters):
     log_p = 0.0
     if hyper_parameters.r_e_loc is not None:
         log_p += jnp.log(
-            norm.pdf(
-                parameters.r_e, hyper_parameters.r_e_loc, hyper_parameters.r_e_scale
-            )
+            p_norm(
+                parameters.r_e - hyper_parameters.r_e_bin/2,
+                parameters.r_e + hyper_parameters.r_e_bin/2,
+                hyper_parameters.r_e_loc,
+                hyper_parameters.r_e_scale)
         )
     if hyper_parameters.r_bg_loc is not None:
         log_p += jnp.log(
-            norm.pdf(
-                parameters.r_bg, hyper_parameters.r_bg_loc, hyper_parameters.r_bg_scale
-            )
+            p_norm(
+                parameters.r_bg - hyper_parameters.r_bg_bin/2,
+                parameters.r_bg + hyper_parameters.r_bg_bin/2,
+                hyper_parameters.r_bg_loc,
+                hyper_parameters.r_bg_scale)
         )
     if hyper_parameters.g_loc is not None:
         log_p += jnp.log(
-            norm.pdf(parameters.gain, hyper_parameters.g_loc, hyper_parameters.g_scale)
+            p_norm(
+                parameters.gain - hyper_parameters.g_bin/2,
+                parameters.gain + hyper_parameters.g_bin/2,
+                hyper_parameters.g_loc,
+                hyper_parameters.g_scale)
         )
     if hyper_parameters.mu_loc is not None:
         log_p += jnp.log(
-            norm.pdf(parameters.mu_ro, hyper_parameters.mu_loc, hyper_parameters.mu_scale)
+            p_norm(
+                parameters.mu_ro - hyper_parameters.mu_bin/2,
+                parameters.mu_ro + hyper_parameters.mu_bin/2,
+                hyper_parameters.mu_loc,
+                hyper_parameters.mu_scale)
         )
 
     if hyper_parameters.sigma_loc is not None:
         log_p += jnp.log(
-            norm.pdf(parameters.sigma_ro, hyper_parameters.sigma_loc, hyper_parameters.sigma_scale)
+            p_norm(
+                parameters.sigma_ro - hyper_parameters.sigma_bin/2,
+                parameters.sigma_ro + hyper_parameters.sigma_bin/2,
+                hyper_parameters.sigma_loc,
+                hyper_parameters.sigma_scale)
         )
 
     # sigma is a uniform prior distribution and will add a constant to all models --> we leave it out
     # log_p_sigma = jnp.log(1.0 / (hyper_parameters.sigma_max - hyper_parameters.sigma_min))
 
     # We don't model a uniform prior distribution for p_on and p_off, because with bounds 0-1 it reduces to 0
 
     return log_p
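Note: with these pieces in place, log_p_parameters scores each configured prior by the log probability mass of the bin centred on the current parameter value, instead of a raw pdf density. A condensed sketch of that pattern (log_prior_mass is a hypothetical helper, and all values are illustrative):

import jax.numpy as jnp
import jax.scipy.stats

def p_norm(x_left, x_right, loc, scale):
    # restated from blinx/fluorescence_model.py
    return (jax.scipy.stats.norm.cdf(x_right, loc=loc, scale=scale)
            - jax.scipy.stats.norm.cdf(x_left, loc=loc, scale=scale))

def log_prior_mass(value, loc, scale, bin_width):
    # hypothetical helper: log of the normal prior mass in the bin centred on value
    return jnp.log(p_norm(value - bin_width / 2, value + bin_width / 2, loc, scale))

# illustrative values, not from the repository
num_x_bins = 1024
log_p = 0.0
log_p += log_prior_mass(2.8, 3.0, 0.5, 3.0 / num_x_bins)  # r_e prior
log_p += log_prior_mass(2.1, 2.0, 0.2, 2.0 / num_x_bins)  # gain prior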
