From a731bb33fe598713ec1fbf1ccaa58bec5225716f Mon Sep 17 00:00:00 2001 From: Rick Sexton Date: Sun, 28 Jan 2024 17:55:27 -0700 Subject: [PATCH] added docs --- INSTALL.md | 2 +- LICENSE | 1 + basicrta/functions.py | 139 +++++++++++++++++++++++------------------- 3 files changed, 77 insertions(+), 65 deletions(-) diff --git a/INSTALL.md b/INSTALL.md index e1ccb54..e665bb1 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -1,5 +1,5 @@ conda create -n basicrta python=3.8 conda activate basicrta conda install mamba -mamba install numpy tqdm matplotlib MDAnalysis scipy pandas seaborn ipython jupyter pymbar +mamba install -c conda-forge numpy tqdm matplotlib MDAnalysis scipy seaborn pip install . diff --git a/LICENSE b/LICENSE index f288702..53d1f3d 100644 --- a/LICENSE +++ b/LICENSE @@ -672,3 +672,4 @@ may consider it more useful to permit linking proprietary applications with the library. If this is what you want to do, use the GNU Lesser General Public License instead of this License. But first, please read . + diff --git a/basicrta/functions.py b/basicrta/functions.py index 9945a99..40074e2 100644 --- a/basicrta/functions.py +++ b/basicrta/functions.py @@ -1,25 +1,21 @@ +"""Analysis functions""" + from matplotlib.ticker import (MultipleLocator, FormatStrFormatter, \ AutoMinorLocator) from matplotlib.patches import Rectangle from matplotlib.collections import PatchCollection -import ast -import multiprocessing +import ast, multiprocessing, os import numpy as np import matplotlib as mpl import matplotlib.pyplot as plt -import os -import pymbar.timeseries as pmts from MDAnalysis.analysis.base import Results -import pickle +import pickle, bz2, gc from glob import glob import seaborn as sns -import math from numpy.random import default_rng from tqdm import tqdm import MDAnalysis as mda -import gc from scipy.optimize import linear_sum_assignment as lsa -import bz2 from scipy import stats from sklearn.cluster import KMeans @@ -76,15 +72,21 @@ def tm(Prot,i): class gibbs(object): - def __init__(self, times, residue, loc=0, ncomp=15, niter=10000): + """Gibbs sampler to estimate parameters of an exponential mixture for a set + of data. Results are stored in gibbs.results, which uses + MDAnalysis.analysis.base.Results(). If 'results=None' the gibbs sampler has + not been executed, which requires calling '.run()' + + """ + + def __init__(self, times, residue, loc=0, ncomp=15, niter=50000): self.times, self.residue = times, residue self.niter, self.loc, self.ncomp = niter, loc, ncomp + self.results = None diff = (np.sort(times)[1:]-np.sort(times)[:-1]) self.ts = diff[diff!=0][0] - def __repr__(self): - return f'Gibbs sampler' def __str__(self): @@ -93,6 +95,7 @@ def __str__(self): def run(self): x = self.times + g = 100 t, _s = get_s(x, self.ts) if not os.path.exists(f'{self.residue}'): os.mkdir(f'{self.residue}') @@ -100,54 +103,59 @@ def run(self): # initialize arrays inrates = 0.5*10**np.arange(-self.ncomp+2, 2, dtype=float) indicator = np.memmap(f'{self.residue}/.indicator_{self.niter}.npy', \ - shape=(self.niter, x.shape[0]), mode='w+', \ + shape=((self.niter+1)//g, x.shape[0]), mode='w+',\ dtype=np.uint8) - mcweights = np.zeros((self.niter + 1, self.ncomp)) - mcrates = np.zeros((self.niter + 1, self.ncomp)) - Ns, lnp = np.zeros((self.niter, self.ncomp)), np.zeros(self.niter) + mcweights = np.zeros(((self.niter+1)//g, self.ncomp)) + mcrates = np.zeros(((self.niter+1)//g, self.ncomp)) + #lnp = np.zeros(self.niter) tmpw = 9*10**(-np.arange(1, self.ncomp+1, dtype=float)) - mcweights[0], mcrates[0] = tmpw/tmpw.sum(), inrates[::-1] + weights, rates = tmpw/tmpw.sum(), inrates[::-1] # guess hyperparameters whypers = np.ones(self.ncomp)/[self.ncomp] rhypers = np.ones((self.ncomp, 2))*[1, 3] # gibbs sampler - for j in tqdm(range(self.niter), desc=f'{self.residue}-K{self.ncomp}', \ + for j in tqdm(range(1, self.niter+1), \ + desc=f'{self.residue}-K{self.ncomp}', \ position=self.loc, leave=False): # compute probabilities - tmp = mcweights[j]*mcrates[j]*np.exp(np.outer(-mcrates[j],x)).T + tmp = weights*rates*np.exp(np.outer(-rates,x)).T z = (tmp.T/tmp.sum(axis=1)).T - # sample and store indicator + # sample indicator s = np.argmax(rng.multinomial(1, z), axis=1) - indicator[j] = s # get occupied states uniqs = np.unique(s) inds = [np.where(s==i)[0] for i in range(self.ncomp)] # compute total time and number of point for each component - Ns[j][:] = np.array([len(inds[i]) for i in range(self.ncomp)]) + Ns = np.array([len(inds[i]) for i in range(self.ncomp)]) Ts = np.array([x[inds[i]].sum() for i in range(self.ncomp)]) # compute log posterior - lnp[j] = np.log(tmp.take(s)).sum()+\ - np.log(mcweights[j][uniqs]).sum()-\ - (mcrates[j][uniqs]*rhypers[uniqs, 1]).sum()+\ - np.log(mcweights[j][uniqs]**(whypers[uniqs]-1)).sum() + #lnp[j] = np.log(tmp.take(s)).sum()+\ + # np.log(mcweights[j][uniqs]).sum()-\ + # (mcrates[j][uniqs]*rhypers[uniqs, 1]).sum()+\ + # np.log(mcweights[j][uniqs]**(whypers[uniqs]-1)).sum() # sample posteriors - mcweights[j+1] = rng.dirichlet(whypers+Ns[j]) - mcrates[j+1] = rng.gamma(rhypers[:,0]+Ns[j], 1/(rhypers[:,1]+Ts)) + weights = rng.dirichlet(whypers+Ns) + rates = rng.gamma(rhypers[:,0]+Ns, 1/(rhypers[:,1]+Ts)) + + # save every g steps + if j%g==0: + ind = j//g-1 + mcweights[ind], mcrates[ind] = weights, rates + indicator[ind] = s - attrs = ["mcweights", "mcrates", "ncomp", "niter", "s", "t", "residue", - "lnp", "times"] + "times"] values = [mcweights, mcrates, self.ncomp, self.niter, _s, t, - self.residue, lnp, x] + self.residue, x] r = save_results(attrs, values) @@ -191,30 +199,29 @@ def estimate_params(processed_results): def process_gibbs(results, cutoff=1e-4): r = results - - #burnin, g, nsample = pmts.detect_equilibration(r.lnp, nskip=20) - burnin, g = int(burnin), int(np.ceil(g)) - - if burnin==0: - burnin = 500 - else: - burnin = burnin + burnin, g = 10000, 100 + burnin_ind = burnin//g - inds = np.where(r.mcweights[burnin::g]>cutoff) - indices = np.arange(burnin, r.niter, g)[inds[0]] - H = np.histogram([len(row[row>cutoff]) for row in r.mcweights[burnin::g]], \ - bins=np.arange(1, r.ncomp+1)) - ncomp = int(H[1][:-1][H[0]==H[0].max()][0]) - - weights, rates = r.mcweights[burnin::g][inds], r.mcrates[burnin::g][inds] - lnp = r.lnp[burnin::g][inds[0]] + inds = np.where(r.mcweights[burnin_ind:]>cutoff) + indices = np.arange(burnin, r.niter+1, g)[inds[0]]//g + lens = [len(row[row>cutoff]) for row in r.mcweights[burnin_ind:]] + ncomp = stats.mode(lens, keepdims=False)[0] + + weights = r.mcweights[burnin_ind::][inds] + rates = r.mcrates[burnin_ind::][inds] + #lnp = r.lnp[burnin::g][inds[0]] + data = np.stack((weights, rates), axis=1) + km = KMeans(n_clusters=ncomp).fit(np.log(data)) Indicator = np.zeros((r.times.shape[0], ncomp)) - - for j,iteration in enumerate(np.unique(indices)): + indicator = np.memmap(f'{r.residue}/.indicator_{r.niter}.npy', \ + shape=((r.niter+1)//g, r.times.shape[0]), mode='r', \ + dtype=np.uint8) + + for j in np.unique(inds[0]): mapinds = km.labels_[inds[0]==j] - for i,indx in enumerate(inds[1][indices==iteration]): - tmpind = np.where(indicator[iteration]==indx)[0] + for i,indx in enumerate(inds[1][inds[0]==j]): + tmpind = np.where(indicator[j]==indx)[0] Indicator[tmpind, mapinds[i]] += 1 Indicator = (Indicator.T/Indicator.sum(axis=1)).T @@ -377,7 +384,8 @@ def plot_post(results, attr, comp=None, save=False, show=False): plt.show() -def plot_trace(results, attr, comp=None, xrange=None, yrange=None, save=False, show=False): +def plot_trace(results, attr, comp=None, xrange=None, yrange=None, save=False, \ + show=False): outdir = results.name if attr=='weights': tmp = getattr(results, 'mcweights') @@ -409,8 +417,10 @@ def plot_trace(results, attr, comp=None, xrange=None, yrange=None, save=False, s if yrange!=None: plt.ylim(yrange[0], yrange[1]) if save: - plt.savefig(f'{outdir}/figs/k{results.ncomp}-trace_{attr}_comps-{"-".join([str(i) for i in comp])}.png') - plt.savefig(f'{outdir}/figs/k{results.ncomp}-trace_{attr}_comps-{"-".join([str(i) for i in comp])}.pdf') + plt.savefig(f'{outdir}/figs/k{results.ncomp}-trace_{attr}_comps-\ + {"-".join([str(i) for i in comp])}.png') + plt.savefig(f'{outdir}/figs/k{results.ncomp}-trace_{attr}_comps-\ + {"-".join([str(i) for i in comp])}.pdf') if show: plt.show() plt.close('all') @@ -590,8 +600,10 @@ def check_results(residues, times, ts): os.mkdir('result_check') for time, residue in zip(times, residues): if os.path.exists(residue): - kmax = glob(f'{residue}/K*_results.pkl')[-1].split('/')[-1].split('/')[-1].split('_')[0][1:] - os.popen(f'cp {residue}/figs/k{kmax}-mean_results.png result_check/{residue}-k{kmax}-results.png') + kmax = glob(f'{residue}/K*_results.pkl')[-1].split('/')[-1].\ + split('/')[-1].split('_')[0][1:] + os.popen(f'cp {residue}/figs/k{kmax}-mean_results.png result_check/\ + {residue}-k{kmax}-results.png') else: t, s = get_s(np.array(time), ts) plt.scatter(t, s, label='data') @@ -643,7 +655,8 @@ def write_trajs(u, time, trajtime, indicator, residue, lipind, step): for comp in np.where(lens != 0)[0]: write_frames, write_Linds = get_write_frames(u, time, trajtime, lipind, comp+2) if len(write_frames) > step: - write_frames, write_Linds = write_frames[::step], write_Linds[::step] + write_frames = write_frames[::step] + write_Linds = write_Linds[::step] with mda.Writer(f"{residue}/comp{comp}_traj.xtc", \ len((prot+chol.residues[0].atoms).atoms)) as W: for i, ts in tqdm(enumerate(u.trajectory[write_frames]), \ @@ -654,18 +667,15 @@ def write_trajs(u, time, trajtime, indicator, residue, lipind, step): def plot_hists(timelens, indicators, residues): - for timelen, indicator, residue in tqdm(zip(timelens, indicators, residues), total=len(timelens), + for timelen, indicator, residue in tqdm(zip(timelens, indicators, residues), + total=len(timelens), desc='ploting hists'): - # framec = (np.round(timelen, 1) * 10).astype(int) - #inds = np.array([np.where(indicator.argmax(axis=0) == i)[0] for i in range(8)], dtype=object) - #lens = np.array([len(ind) for ind in inds]) - #ncomps = len(np.where(lens != 0)[0]) ncomps = indicator[:,0].shape[0] plt.close() for i in range(ncomps): - # h, edges = np.histogram(framec, density=True, bins=50, weights=indicator[i]) - h, edges = np.histogram(timelen, density=True, bins=50, weights=indicator[i]) + h, edges = np.histogram(timelen, density=True, bins=50, \ + weights=indicator[i]) m = 0.5*(edges[1:]+edges[:-1]) plt.plot(m, h, '.', label=i, alpha=0.5) plt.ylabel('p') @@ -738,7 +748,8 @@ def expand_times(contacts): restimes = [] for lip in range(times.shape[1]): for i in range(times[res, lip].shape[0]): - [restimes.append(j) for j in [times[res, lip][i]]*Ns[res, lip][i].astype(int)] + [restimes.append(j) for j in [times[res, lip][i]]*\ + Ns[res, lip][i].astype(int)] alltimes.append(restimes) return np.asarray(alltimes)