@@ -131,13 +131,14 @@ class Gibbs(object):
131
131
"""
132
132
133
133
def __init__ (self , times = None , residue = None , loc = 0 , ncomp = 15 , niter = 110000 ,
134
- cutoff = None , g = 50 , burnin = 10000 ):
134
+ cutoff = None , g = 50 , burnin = 10000 , gskip = 2 ):
135
135
self .times = times
136
136
self .residue = residue
137
137
self .niter = niter
138
138
self .loc = loc
139
139
self .ncomp = ncomp
140
140
self .g = g
141
+ self .gskip = gskip
141
142
self .burnin = burnin
142
143
self .cutoff = cutoff
143
144
self .processed_results = Results ()
@@ -218,7 +219,7 @@ def run(self):
218
219
219
220
self .save ()
220
221
221
- def cluster (self , method = "GaussianMixture" , g = self . g , ** kwargs ):
222
+ def cluster (self , method = "GaussianMixture" , ** kwargs ):
222
223
r"""
223
224
Cluster the processed results using the methods available in
224
225
:class:`sklearn.mixture`
@@ -256,7 +257,7 @@ def cluster(self, method="GaussianMixture", g=self.g, **kwargs):
256
257
r .fit (np .log (train_data ))
257
258
all_labels = r .predict (np .log (data ))
258
259
259
- if ( self .indicator is not None ) and g == self . g :
260
+ if self .indicator is not None :
260
261
indicator = self .indicator [burnin_ind :]
261
262
else :
262
263
indicator = self ._sample_indicator ()
@@ -272,7 +273,7 @@ def cluster(self, method="GaussianMixture", g=self.g, **kwargs):
272
273
setattr (self .processed_results , 'indicator' , pindicator )
273
274
setattr (self .processed_results , 'labels' , all_labels )
274
275
275
- def process_gibbs (self , g = self . g ):
276
+ def process_gibbs (self ):
276
277
r"""
277
278
Process the samples collected from the Gibbs sampler.
278
279
:meth:`process_gibbs` can be called multiple times to check the
@@ -283,10 +284,10 @@ def process_gibbs(self, g=self.g):
283
284
284
285
data_len = len (self .times )
285
286
wcutoff = 10 / data_len
286
- burnin_ind = self .burnin // g
287
+ burnin_ind = self .burnin // self . g
287
288
inds = np .where (self .mcweights [burnin_ind :] > wcutoff )
288
- indices = (np .arange (self .burnin , self .niter + 1 , g )[ inds [ 0 ]] //
289
- g )
289
+ indices = (np .arange (self .burnin , self .niter + 1 , self . g * self . gskip )
290
+ [ inds [ 0 ]] // self . g )
290
291
weights , rates = self .mcweights [burnin_ind :], self .mcrates [burnin_ind :]
291
292
fweights , frates = weights [inds ], rates [inds ]
292
293
@@ -318,10 +319,10 @@ def result_plot(self, remove_noise=False, **kwargs):
318
319
from basicrta .util import mixture_and_plot
319
320
mixture_and_plot (self , remove_noise = remove_noise , ** kwargs )
320
321
321
- def _sample_indicator (self , g = self . g ):
322
- indicator = np .zeros (((self .niter + 1 )// g , self .times .shape [0 ]),
322
+ def _sample_indicator (self ):
323
+ indicator = np .zeros (((self .niter + 1 )// self . g , self .times .shape [0 ]),
323
324
dtype = np .uint8 )
324
- burnin_ind = self .burnin // g
325
+ burnin_ind = self .burnin // self . g
325
326
for i , (w , r ) in enumerate (zip (self .mcweights , self .mcrates )):
326
327
# compute probabilities
327
328
probs = w * r * np .exp (np .outer (- r , self .times )).T
0 commit comments