Skip to content

Commit fc543fe

Browse files
committed
fixed conflicts
1 parent 7a0ad41 commit fc543fe

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

basicrta/gibbs.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def __init__(self, contacts, nproc=1, ncomp=15, niter=110000):
3939
self.ncomp = ncomp
4040
self.contacts = contacts
4141

42-
def run(self, run_resids=None):
42+
def run(self, run_resids=None, g=100):
4343
"""
4444
The :meth:`run` method executes the Gibbs samplers for all residues of
4545
`sel1` present in the contact map, or a list of resids can be provided.
@@ -71,7 +71,7 @@ def run(self, run_resids=None):
7171
run_resids])
7272
residues = residues[inds]
7373
input_list = [[residues[i], times[i].copy(), i % self.nproc,
74-
self.ncomp, self.niter, self.cutoff] for i in
74+
self.ncomp, self.niter, self.cutoff, g] for i in
7575
range(len(residues))]
7676

7777
del contacts, times
@@ -218,7 +218,7 @@ def run(self):
218218

219219
self.save()
220220

221-
def cluster(self, method="GaussianMixture", **kwargs):
221+
def cluster(self, method="GaussianMixture", g=self.g, **kwargs):
222222
r"""
223223
Cluster the processed results using the methods available in
224224
:class:`sklearn.mixture`
@@ -230,7 +230,7 @@ def cluster(self, method="GaussianMixture", **kwargs):
230230
from scipy import stats
231231

232232
clu = getattr(mixture, method)
233-
burnin_ind = self.burnin // self.g
233+
burnin_ind = self.burnin // g
234234
data_len = len(self.times)
235235
wcutoff = 10 / data_len
236236

@@ -256,7 +256,7 @@ def cluster(self, method="GaussianMixture", **kwargs):
256256
r.fit(np.log(train_data))
257257
all_labels = r.predict(np.log(data))
258258

259-
if self.indicator is not None:
259+
if (self.indicator is not None) and g==self.g:
260260
indicator = self.indicator[burnin_ind:]
261261
else:
262262
indicator = self._sample_indicator()
@@ -272,7 +272,7 @@ def cluster(self, method="GaussianMixture", **kwargs):
272272
setattr(self.processed_results, 'indicator', pindicator)
273273
setattr(self.processed_results, 'labels', all_labels)
274274

275-
def process_gibbs(self):
275+
def process_gibbs(self, g=self.g):
276276
r"""
277277
Process the samples collected from the Gibbs sampler.
278278
:meth:`process_gibbs` can be called multiple times to check the
@@ -283,10 +283,10 @@ def process_gibbs(self):
283283

284284
data_len = len(self.times)
285285
wcutoff = 10/data_len
286-
burnin_ind = self.burnin//self.g
286+
burnin_ind = self.burnin//g
287287
inds = np.where(self.mcweights[burnin_ind:] > wcutoff)
288-
indices = (np.arange(self.burnin, self.niter + 1, self.g)[inds[0]] //
289-
self.g)
288+
indices = (np.arange(self.burnin, self.niter + 1, g)[inds[0]] //
289+
g)
290290
weights, rates = self.mcweights[burnin_ind:], self.mcrates[burnin_ind:]
291291
fweights, frates = weights[inds], rates[inds]
292292

@@ -318,10 +318,10 @@ def result_plot(self, remove_noise=False, **kwargs):
318318
from basicrta.util import mixture_and_plot
319319
mixture_and_plot(self, remove_noise=remove_noise, **kwargs)
320320

321-
def _sample_indicator(self):
322-
indicator = np.zeros(((self.niter+1)//self.g, self.times.shape[0]),
321+
def _sample_indicator(self, g=self.g):
322+
indicator = np.zeros(((self.niter+1)//g, self.times.shape[0]),
323323
dtype=np.uint8)
324-
burnin_ind = self.burnin//self.g
324+
burnin_ind = self.burnin//g
325325
for i, (w, r) in enumerate(zip(self.mcweights, self.mcrates)):
326326
# compute probabilities
327327
probs = w*r*np.exp(np.outer(-r, self.times)).T

0 commit comments

Comments
 (0)