Skip to content

Commit

Permalink
Fix protein mutation repex consistency tests (#1054)
Browse files Browse the repository at this point in the history
Co-authored-by: Iván Pulido <ivanpulido@protonmail.com>
  • Loading branch information
zhang-ivy and ijpulidos authored Jun 24, 2022
1 parent 6af4bfd commit 9e61c22
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 30 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/self-hosted-gpu-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ jobs:
shell: bash -l {0}
run: |
export TRAVIS=true
pytest -v --cov-report xml --cov=perses --durations=0 -a "not advanced" -m "gpu_ci or gpu_needed" -k "not test_RESTCapableHybridTopologyFactory_repex_neutral_mutation and not test_RESTCapableHybridTopologyFactory_repex_charge_mutation" perses/tests
pytest -v --cov-report xml --cov=perses --durations=0 -a "not advanced" -m "gpu_ci or gpu_needed" perses/tests
- name: Codecov
if: ${{ github.repository == 'choderalab/perses'
Expand Down
58 changes: 29 additions & 29 deletions perses/tests/test_repex.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,14 @@
import pytest
from perses.tests.utils import enter_temp_directory

# Tolerance for energy differences tests
ENERGY_DIFF_TOLERANCE_KT = 0.5


@pytest.mark.gpu_needed
def test_RESTCapableHybridTopologyFactory_repex_neutral_mutation():
"""
Run ALA->THR and THR->ALA repex with the RESTCapableHybridTopologyFactory and make sure that the free energies are
equal and opposite.
Note: We are using 50 steps per iteration here to speed up the test. We expect larger DDGs and dDDGs as as result.
"""

from pkg_resources import resource_filename
Expand All @@ -28,13 +26,13 @@ def test_RESTCapableHybridTopologyFactory_repex_neutral_mutation():
platform = configure_platform(utils.get_fastest_platform().getName())

data = {}
n_iterations = 250
n_iterations = 3000
mutations = [('ala', 'thr'), ('thr', 'ala')]
with enter_temp_directory() as temp_dir:
for wt_name, mutant_name in mutations:
# Generate htf
pdb_filename = resource_filename("perses", f"data/{wt_name}_vacuum.pdb")
solvent_delivery = PointMutationExecutor(
solvent_delivery = PointMutationExecutor( # TODO: Need to be specify larger padding (1.7 nm) to work with openmm >= 7.8
pdb_filename,
"1",
"2",
Expand All @@ -54,20 +52,20 @@ def test_RESTCapableHybridTopologyFactory_repex_neutral_mutation():
nonbonded_force.setUseDispersionCorrection(True)

# Set up repex simulation
reporter_file = os.path.join(temp_dir, f"{wt_name}-{mutant_name}")
reporter_file = os.path.join(temp_dir, f"{wt_name}-{mutant_name}.nc")
reporter = MultiStateReporter(reporter_file, checkpoint_interval=10)
hss = HybridRepexSampler(mcmc_moves=mcmc.LangevinSplittingDynamicsMove(timestep=4.0 * unit.femtoseconds,
collision_rate=1.0 / unit.picosecond,
n_steps=250,
reassign_velocities=True,
n_steps=50,
reassign_velocities=False,
n_restart_attempts=20,
splitting="V R R R O R R R V",
constraint_tolerance=1e-06),
replica_mixing_scheme='swap-all',
hybrid_factory=htf,
online_analysis_interval=None)
hss.setup(n_states=12, temperature=300 * unit.kelvin, t_max=300 * unit.kelvin,
storage_file=reporter, endstates=True)
storage_file=reporter, minimisation_steps=0, endstates=True)
hss.energy_context_cache = cache.ContextCache(capacity=None, time_to_live=None, platform=platform)
hss.sampler_context_cache = cache.ContextCache(capacity=None, time_to_live=None, platform=platform)

Expand All @@ -81,12 +79,12 @@ def test_RESTCapableHybridTopologyFactory_repex_neutral_mutation():
f_ij, df_ij = analyzer.get_free_energy()
data[f"{wt_name}-{mutant_name}"] = {'free_energy': f_ij[0, -1], 'error': df_ij[0, -1]}

forward_data = data['ala-thr']['free_energy']
reverse_data = data['thr-ala']['free_energy'] * -1
assert np.isclose([forward_data], [reverse_data], atol=ENERGY_DIFF_TOLERANCE_KT), \
f"ALA-THR is {forward_data}. THR-ALA is {reverse_data}."
DDG = abs(data['ala-thr']['free_energy'] - data['thr-ala']['free_energy'] * -1)
dDDG = np.sqrt(data['ala-thr']['error'] ** 2 + data['thr-ala']['error'] ** 2)
assert DDG < 6 * dDDG, f"DDG ({DDG}) is greater than 6 * dDDG ({6 * dDDG})"


@pytest.mark.skip(reason="Currently taking too long in CI.")
@pytest.mark.gpu_needed
def test_RESTCapableHybridTopologyFactory_repex_charge_mutation():
"""
Expand All @@ -101,6 +99,9 @@ def test_RESTCapableHybridTopologyFactory_repex_charge_mutation():
the use of solvated PDBs as input (vs the neutral mutation test uses vacuum PDBs and requires the PointMutationExecutor to solvate).
This difference is because the ARG and LYS dipeptide PDBs were generated using the geometry engine and were therefore clashy,
so we needed to run equilibration before using them as inputs. The ALA and THR PDBs were not clashy.
Note: We are using 50 steps per iteration here to speed up the test. We expect larger DDGs and dDDGs as as result.
"""

import tempfile
Expand All @@ -118,15 +119,15 @@ def test_RESTCapableHybridTopologyFactory_repex_charge_mutation():
platform = configure_platform(utils.get_fastest_platform().getName())

data = {}
n_iterations = 2000
n_iterations = 3000
d_mutations = {'forward': [('arg', 'ala'), ('lys', 'ala')], 'reverse': [('ala', 'arg'), ('ala', 'lys')]}

with enter_temp_directory() as temp_dir:
for mutation_type, mutations in d_mutations.items():
for wt_name, mutant_name in mutations:
# Generate htf
pdb_filename = resource_filename("perses", f"data/{wt_name}_solvated.cif") if mutation_type == 'forward' else os.path.join(temp_dir, f"{wt_name}.cif")
solvent_delivery = PointMutationExecutor(
solvent_delivery = PointMutationExecutor( # TODO: Need to be specify larger padding (1.7 nm) to work with openmm >= 7.8
pdb_filename,
"1",
"2",
Expand All @@ -138,7 +139,7 @@ def test_RESTCapableHybridTopologyFactory_repex_charge_mutation():
)
htf = solvent_delivery.get_apo_rest_htf()

# Save the new
# Save the new positions to use for the reverse transformation
if mutation_type == 'forward':
app.PDBxFile.writeFile(htf._topology_proposal.new_topology,
htf.new_positions(htf.hybrid_positions),
Expand All @@ -158,16 +159,16 @@ def test_RESTCapableHybridTopologyFactory_repex_charge_mutation():
reporter = MultiStateReporter(reporter_file, checkpoint_interval=10)
hss = HybridRepexSampler(mcmc_moves=mcmc.LangevinSplittingDynamicsMove(timestep=4.0 * unit.femtoseconds,
collision_rate=1.0 / unit.picosecond,
n_steps=250,
reassign_velocities=True,
n_steps=50,
reassign_velocities=False,
n_restart_attempts=20,
splitting="V R R R O R R R V",
constraint_tolerance=1e-06),
replica_mixing_scheme='swap-all',
hybrid_factory=htf,
online_analysis_interval=None)
hss.setup(n_states=12, temperature=300 * unit.kelvin, t_max=300 * unit.kelvin,
storage_file=reporter, endstates=True)
hss.setup(n_states=36, temperature=300 * unit.kelvin, t_max=300 * unit.kelvin,
storage_file=reporter, minimisation_steps=0, endstates=True)
hss.energy_context_cache = cache.ContextCache(capacity=None, time_to_live=None, platform=platform)
hss.sampler_context_cache = cache.ContextCache(capacity=None, time_to_live=None, platform=platform)

Expand All @@ -181,12 +182,11 @@ def test_RESTCapableHybridTopologyFactory_repex_charge_mutation():
f_ij, df_ij = analyzer.get_free_energy()
data[f"{wt_name}-{mutant_name}"] = {'free_energy': f_ij[0, -1], 'error': df_ij[0, -1]}

arg_ala_forward = data['arg-ala']['free_energy']
arg_ala_reverse = data['ala-arg']['free_energy']
lys_ala_forward = data['lys-ala']['free_energy']
lys_ala_reverse = data['ala-lys']['free_energy']
arg_ala_arg = arg_ala_forward + arg_ala_reverse
lys_ala_lys = lys_ala_forward + lys_ala_reverse
DDG = data['arg-ala']['free_energy'] + data['ala-arg']['free_energy'] \
- (data['lys-ala']['free_energy'] + data['ala-lys']['free_energy'])
dDDG = np.sqrt(data['arg-ala']['error'] ** 2
+ data['ala-arg']['error'] ** 2
+ data['lys-ala']['error'] ** 2
+ data['ala-lys']['error'] ** 2)

assert np.isclose([arg_ala_arg], [lys_ala_lys], atol=ENERGY_DIFF_TOLERANCE_KT), \
f"ARG-ALA-ARG is {arg_ala_arg}. LYS-ALA-LYS is {lys_ala_lys}."
assert DDG < 6 * dDDG, f"DDG ({DDG}) is greater than 6 * dDDG ({6 * dDDG})"

0 comments on commit 9e61c22

Please sign in to comment.