Skip to content

Commit

Permalink
Cache creation of compound masks
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Jun 5, 2024
1 parent 347a0c0 commit 2f537e4
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 36 deletions.
79 changes: 45 additions & 34 deletions package/MDAnalysis/core/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -936,45 +936,56 @@ def _split_by_compound_indices(self, compound, stable_sort=False):
n_compounds : int
The number of individual compounds.
"""
# Caching would help here, especially when repeating the operation
# over different frames, since these masks are coordinate-independent.
# However, cache must be invalidated whenever new compound indices are
# modified, which is not yet implemented.
# Also, should we include here the grouping for 'group', which is
# Should we include here the grouping for 'group', which is
# essentially a non-split?

cache_key = f"{compound}_masks"
compound_indices = self._get_compound_indices(compound)
compound_sizes = np.bincount(compound_indices)
size_per_atom = compound_sizes[compound_indices]
compound_sizes = compound_sizes[compound_sizes != 0]
unique_compound_sizes = unique_int_1d(compound_sizes)

# Are we already sorted? argsorting and fancy-indexing can be expensive
# so we do a quick pre-check.
needs_sorting = np.any(np.diff(compound_indices) < 0)
if needs_sorting:
# stable sort ensures reproducibility, especially concerning who
# gets to be a compound's atom[0] and be a reference for unwrap.
if stable_sort:
sort_indices = np.argsort(compound_indices, kind='stable')
else:
# Quicksort
sort_indices = np.argsort(compound_indices)
# We must sort size_per_atom accordingly (Issue #3352).
size_per_atom = size_per_atom[sort_indices]

compound_masks = []
atom_masks = []
for compound_size in unique_compound_sizes:
compound_masks.append(compound_sizes == compound_size)

# create new cache or invalidate cache when compound indices changed
if (
cache_key not in self._cache
or np.all(self._cache[cache_key]["compound_indices"]
!= compound_indices)):
compound_sizes = np.bincount(compound_indices)
size_per_atom = compound_sizes[compound_indices]
compound_sizes = compound_sizes[compound_sizes != 0]
unique_compound_sizes = unique_int_1d(compound_sizes)

# Are we already sorted? argsorting and fancy-indexing can be
# expensive so we do a quick pre-check.
needs_sorting = np.any(np.diff(compound_indices) < 0)
if needs_sorting:
atom_masks.append(sort_indices[size_per_atom == compound_size]
.reshape(-1, compound_size))
else:
atom_masks.append(np.where(size_per_atom == compound_size)[0]
.reshape(-1, compound_size))
# stable sort ensures reproducibility, especially concerning
# who gets to be a compound's atom[0] and be a reference for
# unwrap.
if stable_sort:
sort_indices = np.argsort(compound_indices, kind='stable')
else:
# Quicksort
sort_indices = np.argsort(compound_indices)
# We must sort size_per_atom accordingly (Issue #3352).
size_per_atom = size_per_atom[sort_indices]

compound_masks = []
atom_masks = []
for compound_size in unique_compound_sizes:
compound_masks.append(compound_sizes == compound_size)
if needs_sorting:
atom_masks.append(sort_indices[size_per_atom
== compound_size]
.reshape(-1, compound_size))
else:
atom_masks.append(np.where(size_per_atom
== compound_size)[0]
.reshape(-1, compound_size))

self._cache[cache_key] = {
"compound_indices": compound_indices,
"data": (atom_masks, compound_masks, len(compound_sizes))
}

return atom_masks, compound_masks, len(compound_sizes)
return self._cache[cache_key]["data"]

@warn_if_not_unique
@_pbc_to_wrap
Expand Down
39 changes: 37 additions & 2 deletions testsuite/MDAnalysisTests/core/test_accumulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# J. Comput. Chem. 32 (2011), 2319--2327, doi:10.1002/jcc.21787
#
import numpy as np
from numpy.testing import assert_equal, assert_almost_equal
from numpy.testing import assert_equal, assert_almost_equal, assert_allclose

import MDAnalysis as mda
from MDAnalysis.exceptions import DuplicateWarning, NoDataError
Expand Down Expand Up @@ -99,7 +99,6 @@ def test_accumulate_array_attribute_compounds(self, name, compound, level):
ref = [np.ones((len(a), 2, 5)).sum(axis=0) for a in group.atoms.groupby(name).values()]
assert_equal(group.accumulate(np.ones((len(group.atoms), 2, 5)), compound=compound), ref)


class TestTotals(object):
"""Tests the functionality of *Group.total*() like total_mass
and total_charge.
Expand Down Expand Up @@ -291,3 +290,39 @@ def test_quadrupole_moment_fragments(self, group):
assert_almost_equal(quadrupoles,
np.array([0., 0.0011629, 0.1182701, 0.6891748
])) and len(quadrupoles) == n_compounds


class TestCache:
@pytest.fixture()
def group(self):
return mda.Universe(PSF, DCD).atoms

def test_cache(self, group):
"""Test that one cache per compound is created."""
group_nocache = group.copy()
group_cache = group.copy()

for compound in ['residues', 'fragments']:
actual = group_nocache.accumulate("masses", compound=compound)
desired = group_cache.accumulate("masses", compound=compound)

assert_allclose(actual, desired)
group_nocache._cache.pop(f"{compound}_masks")

@pytest.mark.parametrize("compound",
['residues', 'fragments'])
def test_cache_updating(self, group, compound):
"""Test caching of compound_masks for updating atomgroups."""
kwargs = {"attribute": "masses", "compound": compound}

group_nocache = group.select_atoms("prop z < 1.0", updating=True)
group_cache = group.select_atoms("prop z < 1.0", updating=True)

assert_allclose(group_nocache.accumulate(**kwargs),
group_cache.accumulate(**kwargs))

# Clear cache and forward to next frame
group_nocache._cache.pop(f"{compound}_masks")
group.universe.trajectory.next()
assert_allclose(group_nocache.accumulate(**kwargs),
group_cache.accumulate(**kwargs))

0 comments on commit 2f537e4

Please sign in to comment.