diff --git a/ensemble_md/analysis/analyze_traj.py b/ensemble_md/analysis/analyze_traj.py index 72d149a1..c7fb97e1 100644 --- a/ensemble_md/analysis/analyze_traj.py +++ b/ensemble_md/analysis/analyze_traj.py @@ -965,15 +965,76 @@ def plot_swaps(swaps, swap_type='', stack=True, figsize=None): plt.savefig(f'{swap_type}_swaps.png', dpi=600) -def get_dg_evolution(log_file, start_state, end_state): +def get_g_evolution(log_files, N_states, avg_frac=0): + """ + For weight-updating simulations, gets the time series of the alchemical + weights of all states. + + Parameters + ---------- + log_files : list + The list of log file names. + N_states : int + The total number of states in the whole alchemical range. + avg_frac : float + The fraction of the last part of the simulation to be averaged. The + default is 0, which means no averaging. + + + Returns + ------- + g_vecs_all : list + The alchemical weights of all states as a function of time. + It should be a list of lists. + g_vecs_avg : list + The alchemical weights of all states averaged over the last part of + the simulation. If :code:`avg_frac` is 0, :code:`None` will be returned. + """ + g_vecs_all = [] + for log_file in log_files: + f = open(log_file, "r") + lines = f.readlines() + f.close() + + n = -1 + find_equil = False + for line in lines: + n += 1 + if "Count G(in kT)" in line: # this line is lines[n] + w = [] # the list of weights at this time frame + for i in range(1, N_states + 1): + if "<<" in lines[n + i]: + w.append(float(lines[n + i].split()[-3])) + else: + w.append(float(lines[n + i].split()[-2])) + + if find_equil is False: + g_vecs_all.append(w) + + if "Weights have equilibrated" in line: + find_equil = True + w = [float(i) for i in lines[n - 2].split(':')[-1].split()] + g_vecs_all.append(w) + break + + if avg_frac != 0: + n_avg = int(avg_frac * len(g_vecs_all)) + g_vecs_avg = np.mean(g_vecs_all[-n_avg:], axis=0) + else: + g_vecs_avg = None + + return g_vecs_all, g_vecs_avg + + +def get_dg_evolution(log_files, start_state, end_state): """ For weight-updating simulations, gets the time series of the weight difference (:math:`Δg = g_2-g_1`) between the specified states. Parameters ---------- - log_file : str - The log file name. + log_files : list + The list of log file names. start_state : int The index of the state (starting from 0) whose weight is :math:`g_1`. end_state : int @@ -984,45 +1045,22 @@ def get_dg_evolution(log_file, start_state, end_state): dg : list A list of :math:`Δg` values. """ - f = open(log_file, "r") - lines = f.readlines() - f.close() - - n = -1 - find_equil = False - dg = [] N_states = end_state - start_state + 1 # number of states for the range of insterest - for line in lines: - n += 1 - if "Count G(in kT)" in line: # this line is lines[n] - w = [] # the list of weights at this time frame - for i in range(1, N_states + 1): - if "<<" in lines[n + i]: - w.append(float(lines[n + i].split()[-3])) - else: - w.append(float(lines[n + i].split()[-2])) - - if find_equil is False: - dg.append(w[end_state] - w[start_state]) - - if "Weights have equilibrated" in line: - find_equil = True - w = [float(i) for i in lines[n - 2].split(':')[-1].split()] - dg.append(w[end_state] - w[start_state]) - break + g_vecs = get_g_evolution(log_files, N_states) + dg = [g_vecs[i][end_state] - g_vecs[i][start_state] for i in range(len(g_vecs))] return dg -def plot_dg_evolution(log_file, start_state, end_state, start_idx=0, end_idx=-1, dt_log=2): +def plot_dg_evolution(log_files, start_state, end_state, start_idx=0, end_idx=-1, dt_log=2): """ For weight-updating simulations, plots the time series of the weight difference (:math:`Δg = g_2-g_1`) between the specified states. Parameters ---------- - log_file : str or list - The log file name or a list of log file names. + log_files : list + The list of log file names. start_state : int The index of the state (starting from 0) whose weight is :math:`g_1`. end_state : int @@ -1035,12 +1073,7 @@ def plot_dg_evolution(log_file, start_state, end_state, start_idx=0, end_idx=-1, The time interval between two consecutive frames in the log file. The default is 2 ps. """ - if isinstance(log_file, str): - dg = get_dg_evolution(log_file, start_state, end_state) - else: - dg = [] - for f in log_file: - dg += get_dg_evolution(f, start_state, end_state) + dg = get_dg_evolution(log_files, start_state, end_state) # Now we plot dg = dg[start_idx:end_idx] diff --git a/ensemble_md/cli/run_EEXE.py b/ensemble_md/cli/run_EEXE.py index f566c4fa..cfc69c82 100644 --- a/ensemble_md/cli/run_EEXE.py +++ b/ensemble_md/cli/run_EEXE.py @@ -146,7 +146,7 @@ def main(): dhdl_files = [f'sim_{j}/iteration_{i - 1}/dhdl.xvg' for j in range(EEXE.n_sim)] log_files = [f'sim_{j}/iteration_{i - 1}/md.log' for j in range(EEXE.n_sim)] states_ = EEXE.extract_final_dhdl_info(dhdl_files) - wl_delta, weights_, counts = EEXE.extract_final_log_info(log_files) + wl_delta, weights_, counts_ = EEXE.extract_final_log_info(log_files) print() # 3-2. Identify swappable pairs, propose swap(s), calculate P_acc, and accept/reject swap(s) @@ -154,63 +154,64 @@ def main(): # since they are updated by `get_swapping_pattern`. (Even if the function does not explicitly # returns `states_` and `weights_`, `states_` and `weights_` can still be different after # the use of the function.) Therefore, here we create copies for `states_` and `weights_` - # before the use of `get_swapping_pattern`, so we can use them in `histogram_correction`, + # before the use of `get_swapping_pattern`, so we can use them in `weight_correction`, # `combine_weights` and `update_MDP`. states = copy.deepcopy(states_) weights = copy.deepcopy(weights_) + counts = copy.deepcopy(counts_) swap_pattern, swap_list = EEXE.get_swapping_pattern(dhdl_files, states_, weights_) # swap_list will only be used for modify_coords # noqa: E501 - # 3-3. Perform histogram correction/weight combination + # 3-3. Perform weight correction/weight combination if wl_delta != [None for i in range(EEXE.n_sim)]: # weight-updating print(f'\nCurrent Wang-Landau incrementors: {wl_delta}\n') # (1) First we prepare the weights to be combined. - # Note that although averaged weights are sometimes used for histogram correction/weight combination, + # Note that although averaged weights are sometimes used for weight correction/weight combination, # the final weights are always used for calculating the acceptance ratio. if EEXE.N_cutoff != -1 or EEXE.w_combine is not None: - # Only when histogram correction/weight combination is needed. + # Only when weight correction/weight combination is needed. weights_avg, weights_err = EEXE.get_averaged_weights(log_files) weights_input = EEXE.prepare_weights(weights_avg, weights) # weights_input is for weight combination # noqa: E501 - # (2) Now we perform histogram correction/weight combination. + # (2) Now we perform weight correction/weight combination. # The product of this step should always be named as "weights" to be used in update_MDP if EEXE.N_cutoff != -1 and EEXE.w_combine is not None: # perform both if weights_input is None: - # Then only histogram correction will be performed + # Then only weight correction will be performed print('Note: Weight combination is deactivated because the weights are too noisy.') - weights = EEXE.histogram_correction(weights, counts) - _ = EEXE.combine_weights(weights, print_weights=False)[1] # just to print the combiend weights + weights = EEXE.weight_correction(weights, counts) + _ = EEXE.combine_weights(counts_, weights, print_values=False)[1] # just to print the combiend weights # noqa: E501 else: - weights_preprocessed = EEXE.histogram_correction(weights_input, counts) + weights_preprocessed = EEXE.weight_correction(weights_input, counts) if EEXE.verbose is True: print('Performing weight combination ...') else: print('Performing weight combination ...', end='') - weights, g_vec = EEXE.combine_weights(weights_preprocessed) # inverse-variance weighting seems worse # noqa: E501 + counts, weights, g_vec = EEXE.combine_weights(counts_, weights_preprocessed) # inverse-variance weighting seems worse # noqa: E501 EEXE.g_vecs.append(g_vec) elif EEXE.N_cutoff == -1 and EEXE.w_combine is not None: # only perform weight combination - print('Note: No histogram correction will be performed.') + print('Note: No weight correction will be performed.') if weights_input is None: print('Note: Weight combination is deactivated because the weights are too noisy.') - _ = EEXE.combine_weights(weights, print_weights=False)[1] # just to print the combined weights + _ = EEXE.combine_weights(counts_, weights, print_values=False)[1] # just to print the combined weights # noqa: E501 else: if EEXE.verbose is True: print('Performing weight combination ...') else: print('Performing weight combination ...', end='') - weights, g_vec = EEXE.combine_weights(weights_input) # inverse-variance weighting seems worse + counts, weights, g_vec = EEXE.combine_weights(counts_, weights_input) # inverse-variance weighting seems worse # noqa: E501 EEXE.g_vecs.append(g_vec) elif EEXE.N_cutoff != -1 and EEXE.w_combine is None: - # only perform histogram correction + # only perform weight correction print('Note: No weight combination will be performed.') weights = EEXE.histogram_correction(weights_input, counts) - _ = EEXE.combine_weights(weights, print_weights=False)[1] # just to print the combined weights + _ = EEXE.combine_weights(counts_, weights, print_values=False)[1] # just to print the combined weights # noqa: E501 else: - print('Note: No histogram correction will be performed.') + print('Note: No weight correction will be performed.') print('Note: No weight combination will be performed.') - _ = EEXE.combine_weights(weights, print_weights=False)[1] # just to print the combiend weights + _ = EEXE.combine_weights(counts_, weights, print_values=False)[1] # just to print the combiend weights # noqa: E501 # 3-5. Modify the MDP files and swap out the GRO files (if needed) # Here we keep the lambda range set in mdp the same across different iterations in the same folder but swap out the gro file # noqa: E501 diff --git a/ensemble_md/ensemble_EXE.py b/ensemble_md/ensemble_EXE.py index e2f233b6..8a95d8dc 100644 --- a/ensemble_md/ensemble_EXE.py +++ b/ensemble_md/ensemble_EXE.py @@ -15,6 +15,7 @@ import copy import yaml import random +import warnings import importlib import subprocess import numpy as np @@ -237,7 +238,7 @@ def set_params(self, analysis): raise ParameterError("The parameter 's' should be non-negative.") if self.N_cutoff < 0 and self.N_cutoff != -1: - raise ParameterError("The parameter 'N_cutoff' should be non-negative unless no histogram correction is needed, i.e. N_cutoff = -1.") # noqa: E501 + raise ParameterError("The parameter 'N_cutoff' should be non-negative unless no weight correction is needed, i.e. N_cutoff = -1.") # noqa: E501 params_str = ['gro', 'top', 'mdp', 'gmx_executable'] # First check if self.gro and self.top are lists and check their lengths @@ -317,7 +318,7 @@ def set_params(self, analysis): if self.fixed_weights is True: if self.N_cutoff != -1 or self.w_combine is not None: - self.warnings.append('Warning: The histogram correction/weight combination method is specified but will not be used since the weights are fixed.') # noqa: E501 + self.warnings.append('Warning: The weight correction/weight combination method is specified but will not be used since the weights are fixed.') # noqa: E501 # In the case that the warning is ignored, enforce the defaults. self.N_cutoff = -1 self.w_combine = None @@ -1096,7 +1097,7 @@ def accept_or_reject(self, prob_acc): print(" Swap rejected! ") return swap_bool - def histogram_correction(self, weights, counts): + def weight_correction(self, weights, counts): """ Corrects the lambda weights based on the histogram counts. Namely, :math:`g_k' = g_k + ln(N_{k-1}/N_k)`, where :math:`g_k` and :math:`g_k'` @@ -1119,9 +1120,9 @@ def histogram_correction(self, weights, counts): An updated list of lists of corected weights. """ if self.verbose is True: - print("\nPerforming histogram correction for the lambda weights ...") + print("\nPerforming weight correction for the lambda weights ...") else: - print("\nPerforming histogram correction for the lambda weights ...", end="") + print("\nPerforming weight correction for the lambda weights ...", end="") for i in range(len(weights)): # loop over the replicas if self.verbose is True: @@ -1216,67 +1217,108 @@ def prepare_weights(self, weights_avg, weights_final): return weights_output - def combine_weights(self, weights, weights_err=None, print_weights=True): + def combine_weights(self, hist, weights, weights_err=None, print_values=True): """ - Combine alchemical weights across multiple replicas. Note that if - :code:`weights_err` is provided, inverse-variance weighting will be used. + Combine alchemical weights across multiple replicas and adjusts the histogram counts + corerspondingly. Note that if :code:`weights_err` is provided, inverse-variance weighting will be used. Care must be taken since inverse-variance weighting can lead to slower convergence if the provided errors are not accurate. (See :ref:`doc_w_schemes` for mor details.) Parameters ---------- + hist : list + A list of lists of histogram counts of ALL simulations. weights : list - A list of Wang-Landau weights of ALL simulations. + A list of lists alchemical weights of ALL simulations. + weights_err : list, optional + A list of lists of errors corresponding to the values in :code:`weights`. + print_values : bool, optional + Whether to print the histograms and weights for each replica before and + after weight combinationfor each replica. Returns ------- + hist_modified : list + A list of modified histogram counts of ALL simulations. weights_modified : list A list of modified Wang-Landau weights of ALL simulations. g_vec : np.ndarray An array of alchemical weights of the whole range of states. - print_weights : bool - Whether to print the original and combined weights for each replica. """ - if print_weights is True: + if print_values is True: w = np.round(weights, decimals=3).tolist() # just for printing print(' Original weights:') for i in range(len(w)): print(f' Rep {i}: {w[i]}') + print('\n Original histogram counts:') + for i in range(len(hist)): + print(f' Rep {i}: {hist[i]}') # Calculate adjacent weight differences and g_vec - dg_vec = [] + dg_vec, N_ratio_vec = [], [] # alchemical weight differences and histogram count ratios for the whole range dg_adjacent = [list(np.diff(weights[i])) for i in range(len(weights))] + # Suppress the specific warning here + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=RuntimeWarning) + N_ratio_adjacent = [list(np.array(hist[i][1:]) / np.array(hist[i][:-1])) for i in range(len(hist))] + + # Below we deal with the case where the sampling is poor or the WL incrementor just got updated such that + # the histogram counts are 0 for some states, in which case we simply skip histogram correction. + contains_nan = any(np.isnan(value) for sublist in N_ratio_adjacent for value in sublist) # can be caused by 0/0 # noqa: E501 + contains_inf = any(np.isinf(value) for sublist in N_ratio_adjacent for value in sublist) # can be caused by x/0, where x is a finite number # noqa: E501 + skip_hist_correction = contains_nan or contains_inf + if skip_hist_correction: + print('\n Histogram correction is skipped because the histogram counts are 0 for some states.') + if weights_err is not None: dg_adjacent_err = [[np.sqrt(weights_err[i][j] ** 2 + weights_err[i][j + 1] ** 2) for j in range(len(weights_err[i]) - 1)] for i in range(len(weights_err))] # noqa: E501 for i in range(self.n_tot - 1): - dg_list, dg_err_list = [], [] + dg_list, dg_err_list, N_ratio_list = [], [], [] for j in range(len(self.state_ranges)): if i in self.state_ranges[j] and i + 1 in self.state_ranges[j]: idx = self.state_ranges[j].index(i) dg_list.append(dg_adjacent[j][idx]) + N_ratio_list.append(N_ratio_adjacent[j][idx]) if weights_err is not None: dg_err_list.append(dg_adjacent_err[j][idx]) if weights_err is None: dg_vec.append(np.mean(dg_list)) else: dg_vec.append(utils.weighted_mean(dg_list, dg_err_list)[0]) + N_ratio_vec.append(np.prod(N_ratio_list) ** (1 / len(N_ratio_list))) # geometric mean dg_vec.insert(0, 0) + N_ratio_vec.insert(0, hist[0][0]) g_vec = np.array([sum(dg_vec[:(i + 1)]) for i in range(len(dg_vec))]) - - # Determine the vector of alchemical weights for each replica + if skip_hist_correction is False: + # When skip_hist_correction is True, previous lines for calculating N_ratio_vec or N_ratio_list will + # still not error out so it's fine to not add the conditional statement like here, since we will + # have hist_modified = hist at the end anyway. However, if skip_hist_correction, things like + # int(np.nan) will lead to an error, so we put an if condition here. + N_vec = np.array([int(np.ceil(np.prod(N_ratio_vec[:(i + 1)]))) for i in range(len(N_ratio_vec))]) + + # Determine the vector of alchemical weights and histogram counts for each replica weights_modified = np.zeros_like(weights) for i in range(self.n_sim): + hist_modified = [] if self.equil[i] == -1: # unequilibrated weights_modified[i] = list(g_vec[i * self.s: i * self.s + self.n_sub] - g_vec[i * self.s: i * self.s + self.n_sub][0]) # noqa: E501 else: weights_modified[i] = self.equilibrated_weights[i] + if skip_hist_correction is False: + hist_modified = [list(N_vec[self.state_ranges[i]]) for i in range(self.n_sim)] + else: + hist_modified = hist - if print_weights is True: + if print_values is True: w = np.round(weights_modified, decimals=3).tolist() # just for printing print('\n Modified weights:') for i in range(len(w)): print(f' Rep {i}: {w[i]}') + if skip_hist_correction is False: + print('\n Modified histogram counts:') + for i in range(len(hist_modified)): + print(f' Rep {i}: {hist_modified[i]}') if self.verbose is False: print(' DONE') @@ -1284,7 +1326,7 @@ def combine_weights(self, weights, weights_err=None, print_weights=True): else: print(f'\n The alchemical weights of all states: \n {list(np.round(g_vec, decimals=3))}') - return weights_modified, g_vec + return hist_modified, weights_modified, g_vec def _run_grompp(self, n, swap_pattern): """ diff --git a/ensemble_md/tests/test_ensemble_EXE.py b/ensemble_md/tests/test_ensemble_EXE.py index ca5656d0..a832515b 100644 --- a/ensemble_md/tests/test_ensemble_EXE.py +++ b/ensemble_md/tests/test_ensemble_EXE.py @@ -99,7 +99,7 @@ def test_set_params_error(self, params_dict): # 5. Non-negative parameters check_param_error(params_dict, 'n_ex', "The parameter 'n_ex' should be non-negative.", -1) - check_param_error(params_dict, 'N_cutoff', "The parameter 'N_cutoff' should be non-negative unless no histogram correction is needed, i.e. N_cutoff = -1.", -5) # noqa: E501 + check_param_error(params_dict, 'N_cutoff', "The parameter 'N_cutoff' should be non-negative unless no weight correction is needed, i.e. N_cutoff = -1.", -5) # noqa: E501 # 6. String parameters check_param_error(params_dict, 'mdp', "The parameter 'mdp' should be a string.", 3, 'ensemble_md/tests/data/expanded.mdp') # noqa: E501 @@ -144,7 +144,7 @@ def test_set_params_warnings(self, params_dict): params_dict['N_cutoff'] = 1000 EEXE = get_EEXE_instance(params_dict) - warning_1 = 'Warning: The histogram correction/weight combination method is specified but will not be used since the weights are fixed.' # noqa: E501 + warning_1 = 'Warning: The weight correction/weight combination method is specified but will not be used since the weights are fixed.' # noqa: E501 warning_2 = 'Warning: We recommend setting lmc_seed as -1 so the random seed is different for each iteration.' warning_3 = 'Warning: We recommend setting gen_seed as -1 so the random seed is different for each iteration.' assert warning_1 in EEXE.warnings @@ -580,15 +580,15 @@ def test_accept_or_reject(self, params_dict): assert swap_bool_2 is False assert swap_bool_3 is True - def test_historgam_correction(self, params_dict): + def test_weight_correction(self, params_dict): EEXE = get_EEXE_instance(params_dict) - # Case 1: Perform histogram correction (N_cutoff reached) + # Case 1: Perform weight correction (N_cutoff reached) EEXE.N_cutoff = 5000 EEXE.verbose = False # just to increase code coverage weights_1 = [[0, 10.304, 20.073, 29.364]] counts_1 = [[31415, 45701, 55457, 59557]] - weights_1 = EEXE.histogram_correction(weights_1, counts_1) + weights_1 = EEXE.weight_correction(weights_1, counts_1) assert np.allclose(weights_1, [ [ 0, @@ -598,14 +598,17 @@ def test_historgam_correction(self, params_dict): ] ]) # noqa: E501 - # Case 2: Perform histogram correction (N_cutoff not reached by both N_k and N_{k-1}) + # Case 2: Perform weight correction (N_cutoff not reached by both N_k and N_{k-1}) EEXE.verbose = True weights_2 = [[0, 10.304, 20.073, 29.364]] counts_2 = [[3141, 4570, 5545, 5955]] - weights_2 = EEXE.histogram_correction(weights_2, counts_2) + weights_2 = EEXE.weight_correction(weights_2, counts_2) assert np.allclose(weights_2, [[0, 10.304, 20.073, 29.364 + np.log(5545 / 5955)]]) - def test_combine_weights(self, params_dict): + def test_combine_weights_1(self, params_dict): + """ + Here we just test the combined weights, so the values of hist does not matter. + """ EEXE = get_EEXE_instance(params_dict) EEXE.n_tot = 6 EEXE.n_sub = 4 @@ -613,9 +616,9 @@ def test_combine_weights(self, params_dict): EEXE.n_sim = 3 EEXE.state_ranges = [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5]] weights = [[0, 2.1, 4.0, 3.7], [0, 1.7, 1.2, 2.6], [0, -0.4, 0.9, 1.9]] + hist = [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]] - EEXE.w_combine = True - w_1, g_vec_1 = EEXE.combine_weights(weights) + _, w_1, g_vec_1 = EEXE.combine_weights(hist, weights) assert np.allclose(w_1, [ [0, 2.1, 3.9, 3.5], [0, 1.8, 1.4, 2.75], @@ -624,9 +627,25 @@ def test_combine_weights(self, params_dict): weights = [[0, 2.1, 4.0, 3.7], [0, 1.7, 1.2, 2.6], [0, -0.4, 0.9, 1.9]] errors = [[0, 0.1, 0.15, 0.1], [0, 0.12, 0.1, 0.12], [0, 0.12, 0.15, 0.1]] - w_2, g_vec_2 = EEXE.combine_weights(weights, errors) + _, w_2, g_vec_2 = EEXE.combine_weights(hist, weights, errors) assert np.allclose(w_2, [ [0, 2.1, 3.86140725, 3.45417313], [0, 1.76140725, 1.35417313, 2.71436889], [0, -0.40723412, 0.95296164, 1.95296164]]) assert np.allclose(list(g_vec_2), [0, 2.1, 3.861407249466951, 3.4541731330165306, 4.814368891580968, 5.814368891580968]) # noqa: E501 + + def test_combine_weights_2(self, params_dict): + """ + Here we just test the modified histograms, so the values of weights does not matter. + """ + EEXE = get_EEXE_instance(params_dict) + EEXE.n_tot = 6 + EEXE.n_sub = 5 + EEXE.s = 1 + EEXE.n_sim = 2 + EEXE.state_ranges = [[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]] + weights = [[0, 2.1, 4.0, 3.7, 5], [0, 1.7, 1.2, 2.6, 4]] + hist = [[416, 332, 130, 71, 61], [303, 181, 123, 143, 260]] + + hist_modified, _, _ = EEXE.combine_weights(hist, weights) + assert hist_modified == [[416, 332, 161, 98, 98], [332, 161, 98, 98, 178]]