Skip to content

Commit 1108628

Browse files
committed
Added backup plan for changing slopes
1 parent 67fa147 commit 1108628

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

navicat_spock/piecewise_regression.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1100,7 +1100,7 @@ def summary(self):
11001100

11011101
class ModelSelection:
11021102
"""
1103-
Experimental - uses simple BIC based on simple linear model.
1103+
Uses BIC to compare models with different number of breakpoints.
11041104
"""
11051105

11061106
def __init__(

navicat_spock/spock.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def run_spock_from_args(
123123
all_bic = np.zeros((len(idxs), max_breakpoints + 1), dtype=float)
124124
all_n = np.zeros((len(idxs), max_breakpoints + 1), dtype=int)
125125
all_sc = np.zeros((len(idxs), max_breakpoints + 1), dtype=bool)
126+
msels = []
126127
for i, idx in enumerate(idxs):
127128
fitted = False
128129
try:
@@ -139,6 +140,7 @@ def run_spock_from_args(
139140
tolerance=xthresh,
140141
verbose=verb > 2,
141142
)
143+
msels.append(msel)
142144
all_sc[i, :] = np.array(
143145
[
144146
slope_check(summary["slopes"], verb)
@@ -226,6 +228,7 @@ def run_spock_from_args(
226228
print(
227229
f"Fit did not converge with descriptor index {idx}: {tags[idx]}\n due to {m}"
228230
)
231+
msels.append(None)
229232

230233
# Done iterating over descriptors
231234
best_n = np.zeros_like(idxs, dtype=int)
@@ -293,17 +296,21 @@ def run_spock_from_args(
293296
f"Fitting volcano with {n} breakpoints and descriptor index {idx}: {tags[idx]}, as determined from BIC."
294297
)
295298
descriptor = d[:, idx].reshape(-1)
296-
xrange = 0.05 * (max(descriptor) - min(descriptor))
299+
xthresh = 0.05 * (max(descriptor) - min(descriptor))
297300
pw_fit = Fit(
298301
descriptor,
299302
target,
300303
n_breakpoints=n,
301304
weights=weights,
302-
max_iterations=5000,
303-
tolerance=xrange,
305+
max_iterations=n_iter_helper(fitted),
306+
tolerance=xthresh,
304307
)
305308
if not pw_fit.best_muggeo:
306-
raise ConvergenceError("The fitting process did not converge.")
309+
# If for some reason the fit fails now, we use the preliminary fit instead
310+
pw_fit.best_muggeo = msels[idx - 1].models[n - 1].best_muggeo
311+
if not slope_check(pw_fit.get_results()["slopes"], verb):
312+
# If for some reason the fit switched the slopes, we use the preliminary fit instead
313+
pw_fit.best_muggeo = msels[idx - 1].models[n - 1].best_muggeo
307314
if verb > 2:
308315
pw_fit.summary()
309316
# Plot the data, fit, breakpoints and confidence intervals

0 commit comments

Comments
 (0)