@@ -39,7 +39,7 @@ def __init__(self, contacts, nproc=1, ncomp=15, niter=110000):
39
39
self .ncomp = ncomp
40
40
self .contacts = contacts
41
41
42
- def run (self , run_resids = None ):
42
+ def run (self , run_resids = None , g = 100 ):
43
43
"""
44
44
The :meth:`run` method executes the Gibbs samplers for all residues of
45
45
`sel1` present in the contact map, or a list of resids can be provided.
@@ -71,7 +71,7 @@ def run(self, run_resids=None):
71
71
run_resids ])
72
72
residues = residues [inds ]
73
73
input_list = [[residues [i ], times [i ].copy (), i % self .nproc ,
74
- self .ncomp , self .niter , self .cutoff ] for i in
74
+ self .ncomp , self .niter , self .cutoff , g ] for i in
75
75
range (len (residues ))]
76
76
77
77
del contacts , times
@@ -218,7 +218,7 @@ def run(self):
218
218
219
219
self .save ()
220
220
221
- def cluster (self , method = "GaussianMixture" , ** kwargs ):
221
+ def cluster (self , method = "GaussianMixture" , g = self . g , ** kwargs ):
222
222
r"""
223
223
Cluster the processed results using the methods available in
224
224
:class:`sklearn.mixture`
@@ -230,7 +230,7 @@ def cluster(self, method="GaussianMixture", **kwargs):
230
230
from scipy import stats
231
231
232
232
clu = getattr (mixture , method )
233
- burnin_ind = self .burnin // self . g
233
+ burnin_ind = self .burnin // g
234
234
data_len = len (self .times )
235
235
wcutoff = 10 / data_len
236
236
@@ -256,7 +256,7 @@ def cluster(self, method="GaussianMixture", **kwargs):
256
256
r .fit (np .log (train_data ))
257
257
all_labels = r .predict (np .log (data ))
258
258
259
- if self .indicator is not None :
259
+ if ( self .indicator is not None ) and g == self . g :
260
260
indicator = self .indicator [burnin_ind :]
261
261
else :
262
262
indicator = self ._sample_indicator ()
@@ -272,7 +272,7 @@ def cluster(self, method="GaussianMixture", **kwargs):
272
272
setattr (self .processed_results , 'indicator' , pindicator )
273
273
setattr (self .processed_results , 'labels' , all_labels )
274
274
275
- def process_gibbs (self ):
275
+ def process_gibbs (self , g = self . g ):
276
276
r"""
277
277
Process the samples collected from the Gibbs sampler.
278
278
:meth:`process_gibbs` can be called multiple times to check the
@@ -283,10 +283,10 @@ def process_gibbs(self):
283
283
284
284
data_len = len (self .times )
285
285
wcutoff = 10 / data_len
286
- burnin_ind = self .burnin // self . g
286
+ burnin_ind = self .burnin // g
287
287
inds = np .where (self .mcweights [burnin_ind :] > wcutoff )
288
- indices = (np .arange (self .burnin , self .niter + 1 , self . g )[inds [0 ]] //
289
- self . g )
288
+ indices = (np .arange (self .burnin , self .niter + 1 , g )[inds [0 ]] //
289
+ g )
290
290
weights , rates = self .mcweights [burnin_ind :], self .mcrates [burnin_ind :]
291
291
fweights , frates = weights [inds ], rates [inds ]
292
292
@@ -318,10 +318,10 @@ def result_plot(self, remove_noise=False, **kwargs):
318
318
from basicrta .util import mixture_and_plot
319
319
mixture_and_plot (self , remove_noise = remove_noise , ** kwargs )
320
320
321
- def _sample_indicator (self ):
322
- indicator = np .zeros (((self .niter + 1 )// self . g , self .times .shape [0 ]),
321
+ def _sample_indicator (self , g = self . g ):
322
+ indicator = np .zeros (((self .niter + 1 )// g , self .times .shape [0 ]),
323
323
dtype = np .uint8 )
324
- burnin_ind = self .burnin // self . g
324
+ burnin_ind = self .burnin // g
325
325
for i , (w , r ) in enumerate (zip (self .mcweights , self .mcrates )):
326
326
# compute probabilities
327
327
probs = w * r * np .exp (np .outer (- r , self .times )).T
0 commit comments