Skip to content

Commit

Permalink
Merge pull request #719 from OpenFreeEnergy/fix-chkpoint
Browse files Browse the repository at this point in the history
Fix checkpointing
  • Loading branch information
richardjgowers authored Feb 13, 2024
2 parents 8625192 + b43d667 commit 9ebbcbe
Show file tree
Hide file tree
Showing 10 changed files with 196 additions and 83 deletions.
13 changes: 12 additions & 1 deletion openfe/protocols/openmm_afe/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,7 @@ def _get_reporter(
self,
topology: app.Topology,
positions: openmm.unit.Quantity,
simulation_settings: MultiStateSimulationSettings,
output_settings: OutputSettings,
) -> multistate.MultiStateReporter:
"""
Expand All @@ -570,6 +571,11 @@ def _get_reporter(
----------
topology : app.Topology
A Topology of the system being created.
positions : openmm.unit.Quantity
Positions of the pre-alchemical simulation system.
simulation_settings : MultiStateSimulationSettings
Multistate simulation control settings, specifically containing
the amount of time per state sampling iteration.
output_settings: OutputSettings
Output settings for the simulations
Expand All @@ -586,11 +592,15 @@ def _get_reporter(

nc = self.shared_basepath / output_settings.output_filename
chk = output_settings.checkpoint_storage_filename
chk_intervals = settings_validation.convert_checkpoint_interval_to_iterations(
checkpoint_interval=output_settings.checkpoint_interval,
time_per_iteration=simulation_settings.time_per_iteration,
)

reporter = multistate.MultiStateReporter(
storage=nc,
analysis_particle_indices=selection_indices,
checkpoint_interval=output_settings.checkpoint_interval.m,
checkpoint_interval=chk_intervals,
checkpoint_storage=chk,
)

Expand Down Expand Up @@ -914,6 +924,7 @@ def run(self, dry=False, verbose=True,
# 11. Create the multistate reporter & create PDB
reporter = self._get_reporter(
omm_topology, positions,
settings['simulation_settings'],
settings['output_settings'],
)

Expand Down
59 changes: 37 additions & 22 deletions openfe/protocols/openmm_md/plain_md_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,13 +253,14 @@ def _run_MD(simulation: openmm.app.Simulation,
positions: omm_unit.Quantity,
simulation_settings: MDSimulationSettings,
output_settings: MDOutputSettings,
temperature: settings.ThermoSettings.temperature,
barostat_frequency: IntegratorSettings.barostat_frequency,
temperature: unit.Quantity,
barostat_frequency: unit.Quantity,
timestep: unit.Quantity,
equil_steps_nvt: int,
equil_steps_npt: int,
prod_steps: int,
verbose=True,
shared_basepath=None):
shared_basepath=None) -> None:

"""
Energy minimization, Equilibration and Production MD to be reused
Expand All @@ -275,10 +276,12 @@ def _run_MD(simulation: openmm.app.Simulation,
Settings for MD simulation
output_settings: OutputSettingsMD
Settings for output of MD simulation
temperature: settings.ThermoSettings.temperature
temperature: FloatQuantity["kelvin"]
temperature setting
barostat_frequency: IntegratorSettings.barostat_frequency
barostat_frequency: unit.Quantity
Frequency for the barostat
timestep: FloatQuantity["femtosecond"]
Simulation integration timestep
equil_steps_nvt: int
number of steps for NVT equilibration
equil_steps_npt: int
Expand All @@ -291,9 +294,6 @@ def _run_MD(simulation: openmm.app.Simulation,
shared_basepath : Pathlike, optional
Where to run the calculation, defaults to current working directory
Returns
-------
"""
if shared_basepath is None:
shared_basepath = pathlib.Path('.')
Expand Down Expand Up @@ -396,16 +396,29 @@ def _run_MD(simulation: openmm.app.Simulation,
logger.info("running production phase")

# Setup the reporters
write_interval = settings_validation.divmod_time_and_check(
output_settings.trajectory_write_interval,
timestep,
"trajectory_write_interval",
"timestep",
)

checkpoint_interval = settings_validation.get_simsteps(
sim_length=output_settings.checkpoint_interval,
timestep=timestep,
mc_steps=1,
)

simulation.reporters.append(XTCReporter(
file=str(shared_basepath / output_settings.production_trajectory_filename),
reportInterval=output_settings.trajectory_write_interval.m,
reportInterval=write_interval,
atomSubset=selection_indices))
simulation.reporters.append(openmm.app.CheckpointReporter(
file=str(shared_basepath / output_settings.checkpoint_storage_filename),
reportInterval=output_settings.checkpoint_interval.m))
reportInterval=checkpoint_interval))
simulation.reporters.append(openmm.app.StateDataReporter(
str(shared_basepath / output_settings.log_output),
output_settings.checkpoint_interval.m,
checkpoint_interval,
step=True,
time=True,
potentialEnergy=True,
Expand Down Expand Up @@ -597,17 +610,19 @@ def run(self, *, dry=False, verbose=True,
try:

if not dry: # pragma: no-cover
self._run_MD(simulation,
stateA_positions,
sim_settings,
output_settings,
thermo_settings.temperature,
integrator_settings.barostat_frequency,
equil_steps_nvt,
equil_steps_npt,
prod_steps,
shared_basepath=shared_basepath,
)
self._run_MD(
simulation,
stateA_positions,
sim_settings,
output_settings,
thermo_settings.temperature,
integrator_settings.barostat_frequency,
timestep,
equil_steps_nvt,
equil_steps_npt,
prod_steps,
shared_basepath=shared_basepath,
)

finally:

Expand Down
12 changes: 7 additions & 5 deletions openfe/protocols/openmm_rfe/equil_rfe_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,16 +858,18 @@ def run(self, *, dry=False, verbose=True,
)

# a. Create the multistate reporter
# convert checkpoint_interval from time to steps
checkpoint_fs = output_settings.checkpoint_interval.to(unit.femtosecond).m
ts_fs = integrator_settings.timestep.to(unit.femtosecond).m
checkpoint_int = int(round(checkpoint_fs / ts_fs))
# convert checkpoint_interval from time to iterations
chk_intervals = settings_validation.convert_checkpoint_interval_to_iterations(
checkpoint_interval=output_settings.checkpoint_interval,
time_per_iteration=sampler_settings.time_per_iteration,
)

nc = shared_basepath / output_settings.output_filename
chk = output_settings.checkpoint_storage_filename
reporter = multistate.MultiStateReporter(
storage=nc,
analysis_particle_indices=selection_indices,
checkpoint_interval=checkpoint_int,
checkpoint_interval=chk_intervals,
checkpoint_storage=chk,
)

Expand Down
2 changes: 1 addition & 1 deletion openfe/protocols/openmm_utils/omm_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ class Config:
# reporter settings
production_trajectory_filename = 'simulation.xtc'
"""Path to the storage file for analysis. Default 'simulation.xtc'."""
trajectory_write_interval = 5000 * unit.timestep
trajectory_write_interval: FloatQuantity['picosecond'] = 20 * unit.picosecond
"""
Frequency to write the xtc file. Default 20 * unit.picosecond.
"""
Expand Down
149 changes: 117 additions & 32 deletions openfe/protocols/openmm_utils/settings_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,103 @@ def get_simsteps(sim_length: unit.Quantity,
return sim_steps


def divmod_time(
    time: unit.Quantity,
    time_per_iteration: unit.Quantity,
) -> tuple[int, int]:
    """
    Split a span of time into whole iterations plus a leftover.

    Both quantities are converted to attoseconds and rounded before the
    division is carried out.

    Parameters
    ----------
    time : unit.Quantity
        The span of time to divide up.
    time_per_iteration : unit.Quantity
        The length of time covered by a single iteration.

    Returns
    -------
    iterations : int
        How many full iterations fit into ``time``.
    remainder : int
        What is left of ``time`` (in rounded attoseconds) once the full
        iterations have been removed.
    """
    # Express both quantities as rounded attosecond magnitudes
    total_as, per_iteration_as = (
        round(quantity.to(unit.attosecond).m)
        for quantity in (time, time_per_iteration)
    )

    return divmod(total_as, per_iteration_as)


def divmod_time_and_check(numerator: unit.Quantity, denominator: unit.Quantity,
                          numerator_name: str, denominator_name: str) -> int:
    """Divide one time by another, refusing any inexact result.

    For example, a numerator of 20.0 ps and a denominator of 4.0 fs
    gives 5000.

    Parameters
    ----------
    numerator, denominator : unit.Quantity
        The two times to divide (``numerator / denominator``).
    numerator_name, denominator_name : str
        Human readable labels for the two times; only used to build the
        error message when the division is not exact.

    Returns
    -------
    iterations : int
        The whole-number result of the division.

    Raises
    ------
    ValueError
        If ``denominator`` does not divide ``numerator`` exactly. The message
        includes both labels and both values.
    """
    quotient, leftover = divmod_time(numerator, denominator)

    # Happy path first: an exact division has no leftover
    if not leftover:
        return quotient

    raise ValueError(
        f"The {numerator_name} ({numerator}) "
        "does not evenly divide by the "
        f"{denominator_name} ({denominator})"
    )


def convert_checkpoint_interval_to_iterations(
    checkpoint_interval: unit.Quantity,
    time_per_iteration: unit.Quantity,
) -> int:
    """
    Get the number of iterations per checkpoint interval.

    This is necessary as our input settings define checkpoint intervals in
    units of time, but OpenMMTools' MultiStateReporter requires them defined
    as a number of MC iterations.

    Parameters
    ----------
    checkpoint_interval : unit.Quantity
        The amount of time between checkpoints being written.
    time_per_iteration : unit.Quantity
        The amount of time each MC iteration takes.

    Returns
    -------
    iterations : int
        The number of iterations per checkpoint.

    Raises
    ------
    ValueError
        Raised by ``divmod_time_and_check`` if ``checkpoint_interval`` is not
        an exact multiple of ``time_per_iteration``.
    """
    # The two labels below only appear in the ValueError message
    return divmod_time_and_check(
        numerator=checkpoint_interval,
        denominator=time_per_iteration,
        numerator_name="amount of time per checkpoint",
        denominator_name="amount of time per state MCMC move attempt",
    )


def convert_steps_per_iteration(
simulation_settings: MultiStateSimulationSettings,
integrator_settings: IntegratorSettings,
simulation_settings: MultiStateSimulationSettings,
integrator_settings: IntegratorSettings,
) -> int:
"""Convert time per iteration to steps
Expand All @@ -89,19 +183,16 @@ def convert_steps_per_iteration(
steps_per_iteration : int
suitable for input to Integrator
"""
tpi_fs = round(simulation_settings.time_per_iteration.to(unit.attosecond).m)
ts_fs = round(integrator_settings.timestep.to(unit.attosecond).m)
steps_per_iteration, rem = divmod(tpi_fs, ts_fs)

if rem:
raise ValueError(f"time_per_iteration ({simulation_settings.time_per_iteration}) "
f"not divisible by timestep ({integrator_settings.timestep})")

return steps_per_iteration
return divmod_time_and_check(
simulation_settings.time_per_iteration,
integrator_settings.timestep,
"time_per_iteration",
"timestep",
)


def convert_real_time_analysis_iterations(
simulation_settings: MultiStateSimulationSettings,
simulation_settings: MultiStateSimulationSettings,
) -> tuple[Optional[int], Optional[int]]:
"""Convert time units in Settings to various other units
Expand All @@ -127,31 +218,25 @@ def convert_real_time_analysis_iterations(
# option to turn off real time analysis
return None, None

tpi_fs = round(simulation_settings.time_per_iteration.to(unit.attosecond).m)

# convert real_time_analysis time to interval
# rta_its must be number of MCMC iterations
# i.e. rta_fs / tpi_fs -> number of iterations
rta_fs = round(simulation_settings.real_time_analysis_interval.to(unit.attosecond).m)

rta_its, rem = divmod(rta_fs, tpi_fs)
if rem:
raise ValueError(f"real_time_analysis_interval ({simulation_settings.real_time_analysis_interval}) "
f"is not divisible by time_per_iteration ({simulation_settings.time_per_iteration})")

# convert RTA_minimum_time to iterations
rta_min_fs = round(simulation_settings.real_time_analysis_minimum_time.to(unit.attosecond).m)
rta_min_its, rem = divmod(rta_min_fs, tpi_fs)
if rem:
raise ValueError(f"real_time_analysis_minimum_time ({simulation_settings.real_time_analysis_minimum_time}) "
f"is not divisible by time_per_iteration ({simulation_settings.time_per_iteration})")
rta_its = divmod_time_and_check(
simulation_settings.real_time_analysis_interval,
simulation_settings.time_per_iteration,
"real_time_analysis_interval",
"time_per_iteration",
)
rta_min_its = divmod_time_and_check(
simulation_settings.real_time_analysis_minimum_time,
simulation_settings.time_per_iteration,
"real_time_analysis_minimum_time",
"time_per_iteration",
)

return rta_its, rta_min_its


def convert_target_error_from_kcal_per_mole_to_kT(
temperature,
target_error,
temperature,
target_error,
) -> float:
"""Convert kcal/mol target error to kT units
Expand Down
12 changes: 7 additions & 5 deletions openfe/tests/protocols/test_openmm_afe_slow.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,20 @@ def test_openmm_run_engine(platform,

# Run a really short calculation to check everything is going well
s = openmm_afe.AbsoluteSolvationProtocol.default_settings()
s.alchemsampler_settings.n_repeats = 1
s.protocol_repeats = 1
s.solvent_output_settings.output_indices = "resname UNK"
s.vacuum_simulation_settings.equilibration_length = 0.1 * unit.picosecond
s.vacuum_simulation_settings.production_length = 0.1 * unit.picosecond
s.solvent_simulation_settings.equilibration_length = 0.1 * unit.picosecond
s.solvent_simulation_settings.production_length = 0.1 * unit.picosecond
s.vacuum_engine_settings.compute_platform = platform
s.solvent_engine_settings.compute_platform = platform
s.alchemsampler_settings.steps_per_iteration = 5 * unit.timestep
s.vacuum_output_settings.checkpoint_interval = 5 * unit.timestep
s.solvent_output_settings.checkpoint_interval = 5 * unit.timestep
s.alchemsampler_settings.n_replicas = 20
s.vacuum_simulation_settings.time_per_iteration = 20 * unit.femtosecond
s.solvent_simulation_settings.time_per_iteration = 20 * unit.femtosecond
s.vacuum_output_settings.checkpoint_interval = 20 * unit.femtosecond
s.solvent_output_settings.checkpoint_interval = 20 * unit.femtosecond
s.vacuum_simulation_settings.n_replicas = 20
s.solvent_simulation_settings.n_replicas = 20
s.lambda_settings.lambda_elec = \
[0.0, 0.25, 0.5, 0.75, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
Expand Down
Loading

0 comments on commit 9ebbcbe

Please sign in to comment.