Skip to content

Commit 8099615

Browse files
committed
Merge branch 'new_intensity'
2 parents 9fbc5f3 + 385f8ee commit 8099615

File tree

3 files changed

+23
-51
lines changed

3 files changed

+23
-51
lines changed

blinx/estimate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# FIXME: post_process should be renamed and find a new home
1313
from .post_process import post_process as find_most_likely_y
1414
from .trace_model import get_trace_log_likelihood
15-
from .utils import find_local_maxima
15+
from .utils import find_maximum
1616

1717

1818
def estimate_y(traces, max_y, parameter_ranges=None, hyper_parameters=None):
@@ -246,9 +246,9 @@ def get_initial_parameter_guesses(traces, y, parameter_ranges, hyper_parameters)
246246
)
247247

248248
# find locations where parameters maximize log likelihoods
249-
min_indices = find_local_maxima(trace_log_likelihoods, num_guesses)
249+
min_index = find_maximum(trace_log_likelihoods)
250250

251-
guesses.append(parameters[min_indices])
251+
guesses.append(parameters[min_index])
252252

253253
# all guesses are stored in 'guesses', the following stacks them together
254254
# as if we vmap'ed over traces:

blinx/trace_model.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,14 @@ def single_optimal_trace(trace, y, parameters, hyper_parameters):
105105
106106
"""
107107

108+
r_e = parameters.r_e
109+
r_bg = parameters.r_bg
110+
mu_ro = parameters.mu_ro
111+
sigma_ro = parameters.sigma_ro
112+
gain = parameters.gain
113+
p_on = parameters.p_on
114+
p_off = parameters.p_off
115+
108116
zs = jnp.arange(0, y + 1)
109117

110118
# Discretize the trace into bins
@@ -116,20 +124,12 @@ def single_optimal_trace(trace, y, parameters, hyper_parameters):
116124
x_left = (trace // bin_width) * bin_width
117125
x_right = x_left + bin_width
118126

119-
p_transition = create_transition_matrix(y, parameters.p_on, parameters.p_off)
127+
p_transition = create_transition_matrix(y, p_on, p_off)
120128
p_initial = get_steady_state(p_transition)
121129
p_measurement = jax.vmap(
122130
p_x_given_z,
123-
in_axes=(None, None, 0, None, None, None, None),
124-
)(
125-
x_left,
126-
x_right,
127-
zs,
128-
parameters.mu,
129-
parameters.mu_bg,
130-
parameters.sigma,
131-
hyper_parameters,
132-
)
131+
in_axes=(None, None, 0, None, None, None, None, None, None),
132+
)(x_left, x_right, zs, r_e, r_bg, mu_ro, sigma_ro, gain, hyper_parameters)
133133

134134
return get_optimal_states(p_measurement, p_initial, p_transition)
135135

blinx/utils.py

Lines changed: 9 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,46 +3,18 @@
33
import scipy.signal
44

55

6-
def find_local_maxima(matrix, num_maxima=None):
7-
# convert to numpy array
8-
matrix = np.asarray(matrix)
6+
def find_maximum(matrix):
7+
temp_matrix = np.array(matrix)
98

10-
if num_maxima is None:
11-
num_maxima = matrix.size
9+
# argmax will return index of any nans
10+
# so replace nans with -inf
1211

13-
# pad matrix with -inf
14-
padded = np.ones(tuple(s + 2 for s in matrix.shape), dtype=matrix.dtype)
15-
padded *= -np.inf
16-
slices = tuple(slice(1, s + 1) for s in matrix.shape)
17-
padded[slices] = matrix
12+
mask = np.isnan(temp_matrix)
1813

19-
padded_indices = scipy.signal.argrelmax(np.asarray(padded), mode="wrap")
20-
# indices into original matrix without padding
21-
indices = tuple(i - 1 for i in padded_indices)
14+
temp_matrix[mask] = -np.inf
2215

23-
# set all non-maxima to -inf
24-
maxima = np.ones_like(matrix)
25-
maxima *= -np.inf
26-
maxima[indices] = matrix[indices]
16+
index = np.argmax(temp_matrix)
2717

28-
# get all maximum values, sorted
29-
values, indices = np.unique(maxima, return_index=True)
18+
index = np.unravel_index(index, temp_matrix.shape)
3019

31-
assert values[0] == -np.inf
32-
33-
# first index should point to -np.inf, drop it
34-
indices = indices[1:]
35-
36-
# values, indices = np.unique(np.isfinite(maxima), return_index=True)
37-
38-
# retain only last num_maxima values
39-
if len(indices) > num_maxima:
40-
indices = indices[-num_maxima:]
41-
42-
# convert back to non-flattened indices
43-
indices = np.unravel_index(indices, matrix.shape)
44-
45-
# convert to jax array
46-
indices = tuple(jnp.array(i) for i in indices)
47-
48-
return indices
20+
return tuple(jnp.expand_dims(jnp.array(i), axis=-1) for i in index)

0 commit comments

Comments
 (0)