Skip to content

Commit 2a46b83

Browse files
committed
fixed broken module options
1 parent d3b0482 commit 2a46b83

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

basicrta/gibbs.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -131,13 +131,14 @@ class Gibbs(object):
131131
"""
132132

133133
def __init__(self, times=None, residue=None, loc=0, ncomp=15, niter=110000,
134-
cutoff=None, g=50, burnin=10000):
134+
cutoff=None, g=50, burnin=10000, gskip=2):
135135
self.times = times
136136
self.residue = residue
137137
self.niter = niter
138138
self.loc = loc
139139
self.ncomp = ncomp
140140
self.g = g
141+
self.gskip = gskip
141142
self.burnin = burnin
142143
self.cutoff = cutoff
143144
self.processed_results = Results()
@@ -218,7 +219,7 @@ def run(self):
218219

219220
self.save()
220221

221-
def cluster(self, method="GaussianMixture", g=self.g, **kwargs):
222+
def cluster(self, method="GaussianMixture", **kwargs):
222223
r"""
223224
Cluster the processed results using the methods available in
224225
:class:`sklearn.mixture`
@@ -256,7 +257,7 @@ def cluster(self, method="GaussianMixture", g=self.g, **kwargs):
256257
r.fit(np.log(train_data))
257258
all_labels = r.predict(np.log(data))
258259

259-
if (self.indicator is not None) and g==self.g:
260+
if self.indicator is not None:
260261
indicator = self.indicator[burnin_ind:]
261262
else:
262263
indicator = self._sample_indicator()
@@ -272,7 +273,7 @@ def cluster(self, method="GaussianMixture", g=self.g, **kwargs):
272273
setattr(self.processed_results, 'indicator', pindicator)
273274
setattr(self.processed_results, 'labels', all_labels)
274275

275-
def process_gibbs(self, g=self.g):
276+
def process_gibbs(self):
276277
r"""
277278
Process the samples collected from the Gibbs sampler.
278279
:meth:`process_gibbs` can be called multiple times to check the
@@ -283,10 +284,10 @@ def process_gibbs(self, g=self.g):
283284

284285
data_len = len(self.times)
285286
wcutoff = 10/data_len
286-
burnin_ind = self.burnin//g
287+
burnin_ind = self.burnin//self.g
287288
inds = np.where(self.mcweights[burnin_ind:] > wcutoff)
288-
indices = (np.arange(self.burnin, self.niter + 1, g)[inds[0]] //
289-
g)
289+
indices = (np.arange(self.burnin, self.niter + 1, self.g*self.gskip)
290+
[inds[0]] // self.g)
290291
weights, rates = self.mcweights[burnin_ind:], self.mcrates[burnin_ind:]
291292
fweights, frates = weights[inds], rates[inds]
292293

@@ -318,10 +319,10 @@ def result_plot(self, remove_noise=False, **kwargs):
318319
from basicrta.util import mixture_and_plot
319320
mixture_and_plot(self, remove_noise=remove_noise, **kwargs)
320321

321-
def _sample_indicator(self, g=self.g):
322-
indicator = np.zeros(((self.niter+1)//g, self.times.shape[0]),
322+
def _sample_indicator(self):
323+
indicator = np.zeros(((self.niter+1)//self.g, self.times.shape[0]),
323324
dtype=np.uint8)
324-
burnin_ind = self.burnin//g
325+
burnin_ind = self.burnin//self.g
325326
for i, (w, r) in enumerate(zip(self.mcweights, self.mcrates)):
326327
# compute probabilities
327328
probs = w*r*np.exp(np.outer(-r, self.times)).T

0 commit comments

Comments
 (0)