def extract_work_distributions(dataset: netCDF4.Dataset, state1_idx: int, state2_idx: int, res_idx: int) -> tuple:
    """Extract the forward and reverse NCMC work distributions between two states of a residue.

    Parameters
    ----------
    dataset : netCDF4.Dataset
        A dataset with a ``Protons/NCMC`` group to analyze.
    state1_idx : int
        The "from" state index, defining the forward protocol.
    state2_idx : int
        The "to" state index, defining the forward protocol.
    res_idx : int
        The titratable residue index (column of the state variables).

    Returns
    -------
    tuple of (np.ndarray, np.ndarray)
        The forward work values and the negated reverse work values, each in
        the order in which the updates appear in the ``update`` variable.

    Notes
    -----
    The distribution of the reverse proposals is returned as -W, to give it
    the same sign as the forward distribution.

    When ``state1_idx == state2_idx`` the matching proposals are reported as
    forward work only and the reverse array is empty.
    """
    ncmc = dataset["Protons/NCMC"]

    # The "update" variable is 1-indexed. Read it in one slice and convert it
    # to 0-based row indices instead of iterating the variable per element.
    rows = np.asarray(ncmc["update"][:]) - 1

    initial_states = ncmc["initial_state"][:, res_idx][rows]
    proposed_states = ncmc["proposed_state"][:, res_idx][rows]
    work = ncmc["total_work"][:][rows]

    # np.ma.filled turns masked comparison results into False, so entries
    # without a valid state never match; plain arrays pass through unchanged.
    is_forward = np.ma.filled(
        (initial_states == state1_idx) & (proposed_states == state2_idx), False
    )
    # Exclude rows already counted as forward so that a proposal is never
    # placed in both distributions (only relevant if both indices are equal).
    is_reverse = ~is_forward & np.ma.filled(
        (initial_states == state2_idx) & (proposed_states == state1_idx), False
    )

    forward_work = np.asarray(work[is_forward])
    # Use negative value of the work so that the two distributions have the
    # same sign.
    neg_reverse_work = -np.asarray(work[is_reverse])

    return forward_work, neg_reverse_work
def test_extracting_work(self):
    """Extract work distributions for a single residue and check their basic contract."""
    # Local import so this test does not depend on the module's import block.
    import numpy as np

    with netCDF4.Dataset(self.abl_imatinib_netcdf, 'r') as dataset:
        forward, neg_reverse = analysis.extract_work_distributions(dataset, 0, 1, -1)
        # Same residue with the roles of the two states swapped.
        swapped_forward, swapped_neg_reverse = analysis.extract_work_distributions(dataset, 1, 0, -1)

    # The function returns the forward and the negated reverse work as arrays.
    assert isinstance(forward, np.ndarray)
    assert isinstance(neg_reverse, np.ndarray)

    # Swapping the states exchanges the two distributions and flips their
    # sign, because the reverse work is reported as -W.
    np.testing.assert_array_equal(swapped_forward, -neg_reverse)
    np.testing.assert_array_equal(swapped_neg_reverse, -forward)