Skip to content

Commit

Permalink
allow MultiStateReporter to write positions and velocities at a diffe…
Browse files Browse the repository at this point in the history
…rent frequency to energies data.
  • Loading branch information
richardjgowers committed Jul 13, 2023
1 parent b379af3 commit 4805553
Showing 1 changed file with 61 additions and 27 deletions.
88 changes: 61 additions & 27 deletions openmmtools/multistate/multistatereporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@ class MultiStateReporter(object):
analysis_particle_indices : tuple of ints, Optional. Default: () (empty tuple)
If specified, it will serialize positions and velocities for the specified particles, at every iteration, in the
reporter storage (.nc) file. If empty, no positions or velocities will be stored in this file for any atoms.
position_interval : int or None, default 1
the frequency at which to write positions relative to analysis information
None would prevent information being written
velocity_interval : int or None, default 1
the frequency at which to write positions relative to analysis information
None would prevent information being written
Attributes
----------
Expand All @@ -113,7 +119,10 @@ class MultiStateReporter(object):
"""
def __init__(self, storage, open_mode=None,
checkpoint_interval=50, checkpoint_storage=None,
analysis_particle_indices=()):
analysis_particle_indices=(),
position_interval=1,
velocity_interval=1,
):

# Warn that API is experimental
logger.warn('Warning: The openmmtools.multistate API is experimental and may change in future releases')
Expand All @@ -136,6 +145,9 @@ def __init__(self, storage, open_mode=None,
self._checkpoint_interval = checkpoint_interval
# Cast to tuple no mater what 1-D-like input was given
self._analysis_particle_indices = tuple(analysis_particle_indices)
self._position_interval = position_interval
self._velocity_interval = velocity_interval

if open_mode is not None:
self.open(open_mode)
# TODO: Maybe we want to expose this flag to control ovrwriting/appending
Expand Down Expand Up @@ -202,6 +214,14 @@ def checkpoint_interval(self):
"""Returns the checkpoint interval"""
return self._checkpoint_interval

@property
def position_interval(self):
return self._position_interval

@property
def velocity_interval(self):
return self._velocity_interval

def storage_exists(self, skip_size=False):
"""
Check if the storage files exist on disk.
Expand Down Expand Up @@ -415,6 +435,8 @@ def _initialize_storage_file(self, ncfile, nc_name, convention):
ncfile.ConventionVersion = '0.2'
ncfile.DataUsedFor = nc_name
ncfile.CheckpointInterval = self._checkpoint_interval
ncfile.PositionInterval = self._position_interval
ncfile.VelocityInterval = self._velocity_interval

# Create and initialize the global variables
nc_last_good_iter = ncfile.createVariable('last_iteration', int, 'scalar')
Expand Down Expand Up @@ -1647,35 +1669,47 @@ def _write_sampler_states_to_given_file(self, sampler_states: list, iteration: i
write_iteration = self._calculate_checkpoint_iteration(iteration)
else:
write_iteration = iteration

# write out pos/vel - if checkpointing,
# or if interval matches desired frequency
write_pos = (storage_file == 'checkpoint' or
(self._position_interval is not None
and not (write_iteration % self._position_interval)))
write_vel = (storage_file == 'checkpoint' or
(self._velocity_interval is not None
and not (write_iteration % self._velocity_interval)))

# Write the sampler state if we are on the checkpoint interval OR if told to ignore the interval
if write_iteration is not None:
# Store sampler states.
# Create a numpy array to avoid making multiple (possibly inefficient) calls to netCDF assignments
positions = np.zeros([n_replicas, n_particles, 3])
for replica_index, sampler_state in enumerate(sampler_states):
# Store positions in memory first
x = sampler_state.positions / unit.nanometers
positions[replica_index, :, :] = x[:, :]
# Store positions
storage.variables['positions'][write_iteration, :, :, :] = positions

# Create a numpy array to avoid making multiple (possibly inefficient) calls to netCDF assignments
velocities = np.zeros([n_replicas, n_particles, 3])
for replica_index, sampler_state in enumerate(sampler_states):
if sampler_state._unitless_velocities is not None:
# Store velocities in memory first
x = sampler_state.velocities / (unit.nanometer/unit.picoseconds) # _unitless_velocities
velocities[replica_index, :, :] = x[:, :]
# Store velocites
# TODO: This stores velocities as zeros if no velocities are present in the sampler state. Making restored
# sampler_state different from origin.
if 'velocities' not in storage.variables:
# create variable with expected dimensions and shape
storage.createVariable('velocities', storage.variables['positions'].dtype,
dimensions=storage.variables['positions'].dimensions)
storage.variables['velocities'][write_iteration, :, :, :] = velocities

if is_periodic:
if write_pos:
# Create a numpy array to avoid making multiple (possibly inefficient) calls to netCDF assignments
positions = np.zeros([n_replicas, n_particles, 3])
for replica_index, sampler_state in enumerate(sampler_states):
# Store positions in memory first
x = sampler_state.positions / unit.nanometers
positions[replica_index, :, :] = x[:, :]
# Store positions
storage.variables['positions'][write_iteration, :, :, :] = positions

if write_vel:
# Create a numpy array to avoid making multiple (possibly inefficient) calls to netCDF assignments
velocities = np.zeros([n_replicas, n_particles, 3])
for replica_index, sampler_state in enumerate(sampler_states):
if sampler_state._unitless_velocities is not None:
# Store velocities in memory first
x = sampler_state.velocities / (unit.nanometer/unit.picoseconds) # _unitless_velocities
velocities[replica_index, :, :] = x[:, :]
# Store velocites
# TODO: This stores velocities as zeros if no velocities are present in the sampler state. Making restored
# sampler_state different from origin.
if 'velocities' not in storage.variables:
# create variable with expected dimensions and shape
storage.createVariable('velocities', storage.variables['positions'].dtype,
dimensions=storage.variables['positions'].dimensions)
storage.variables['velocities'][write_iteration, :, :, :] = velocities

if is_periodic and write_pos:
# Store box vectors and volume.
# Allocate whole write to memory first
box_vectors = np.zeros([n_replicas, 3, 3])
Expand Down

0 comments on commit 4805553

Please sign in to comment.