Skip to content

Commit ea4e514

Browse files
committed
renamed variables in Gibbs sampler to match paper
1 parent 2a46b83 commit ea4e514

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

basicrta/gibbs.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -193,29 +193,29 @@ def run(self):
193193
desc=f'{self.residue}-K{self.ncomp}',
194194
position=self.loc, leave=False):
195195

196-
# compute probabilities
196+
# compute probabilities (equation 9)
197197
tmp = weights*rates*np.exp(np.outer(-rates, self.times)).T
198-
z = (tmp.T/tmp.sum(axis=1)).T
198+
psample = (tmp.T/tmp.sum(axis=1)).T
199199

200200
# sample indicator
201-
s = np.argmax(rng.multinomial(1, z), axis=1)
201+
z = np.argmax(rng.multinomial(1, psample), axis=1)
202202

203203
# get indicator for each data point
204-
inds = [np.where(s == i)[0] for i in range(self.ncomp)]
204+
inds = [np.where(z == i)[0] for i in range(self.ncomp)]
205205

206206
# compute total time and number of point for each component
207207
Ns = np.array([len(inds[i]) for i in range(self.ncomp)])
208208
Ts = np.array([self.times[inds[i]].sum() for i in range(self.ncomp)])
209209

210-
# sample posteriors
210+
# sample posteriors (equations 7 and 8)
211211
weights = rng.dirichlet(self.whypers+Ns)
212212
rates = rng.gamma(self.rhypers[:, 0]+Ns, 1/(self.rhypers[:, 1]+Ts))
213213

214214
# save every g steps
215215
if j % self.g == 0:
216216
ind = j//self.g-1
217217
self.mcweights[ind], self.mcrates[ind] = weights, rates
218-
self.indicator[ind] = s
218+
self.indicator[ind] = z
219219

220220
self.save()
221221

@@ -231,7 +231,7 @@ def cluster(self, method="GaussianMixture", **kwargs):
231231
from scipy import stats
232232

233233
clu = getattr(mixture, method)
234-
burnin_ind = self.burnin // g
234+
burnin_ind = self.burnin // (self.g*self.gskip)
235235
data_len = len(self.times)
236236
wcutoff = 10 / data_len
237237

0 commit comments

Comments
 (0)