From 2f537e46a23e64f1987fe916863edcec08a494e7 Mon Sep 17 00:00:00 2001 From: Philip Loche Date: Tue, 4 Jun 2024 15:15:31 +0200 Subject: [PATCH] Cache creation of compound masks --- package/MDAnalysis/core/groups.py | 79 +++++++++++-------- .../MDAnalysisTests/core/test_accumulate.py | 39 ++++++++- 2 files changed, 82 insertions(+), 36 deletions(-) diff --git a/package/MDAnalysis/core/groups.py b/package/MDAnalysis/core/groups.py index 1cb57f1f342..7d6f74a055b 100644 --- a/package/MDAnalysis/core/groups.py +++ b/package/MDAnalysis/core/groups.py @@ -936,45 +936,56 @@ def _split_by_compound_indices(self, compound, stable_sort=False): n_compounds : int The number of individual compounds. """ - # Caching would help here, especially when repeating the operation - # over different frames, since these masks are coordinate-independent. - # However, cache must be invalidated whenever new compound indices are - # modified, which is not yet implemented. - # Also, should we include here the grouping for 'group', which is + # Should we include here the grouping for 'group', which is # essentially a non-split? + cache_key = f"{compound}_masks" compound_indices = self._get_compound_indices(compound) - compound_sizes = np.bincount(compound_indices) - size_per_atom = compound_sizes[compound_indices] - compound_sizes = compound_sizes[compound_sizes != 0] - unique_compound_sizes = unique_int_1d(compound_sizes) - - # Are we already sorted? argsorting and fancy-indexing can be expensive - # so we do a quick pre-check. - needs_sorting = np.any(np.diff(compound_indices) < 0) - if needs_sorting: - # stable sort ensures reproducibility, especially concerning who - # gets to be a compound's atom[0] and be a reference for unwrap. - if stable_sort: - sort_indices = np.argsort(compound_indices, kind='stable') - else: - # Quicksort - sort_indices = np.argsort(compound_indices) - # We must sort size_per_atom accordingly (Issue #3352). - size_per_atom = size_per_atom[sort_indices] - - compound_masks = [] - atom_masks = [] - for compound_size in unique_compound_sizes: - compound_masks.append(compound_sizes == compound_size) + + # create new cache or invalidate cache when compound indices changed + if ( + cache_key not in self._cache + or np.all(self._cache[cache_key]["compound_indices"] + != compound_indices)): + compound_sizes = np.bincount(compound_indices) + size_per_atom = compound_sizes[compound_indices] + compound_sizes = compound_sizes[compound_sizes != 0] + unique_compound_sizes = unique_int_1d(compound_sizes) + + # Are we already sorted? argsorting and fancy-indexing can be + # expensive so we do a quick pre-check. + needs_sorting = np.any(np.diff(compound_indices) < 0) if needs_sorting: - atom_masks.append(sort_indices[size_per_atom == compound_size] - .reshape(-1, compound_size)) - else: - atom_masks.append(np.where(size_per_atom == compound_size)[0] - .reshape(-1, compound_size)) + # stable sort ensures reproducibility, especially concerning + # who gets to be a compound's atom[0] and be a reference for + # unwrap. + if stable_sort: + sort_indices = np.argsort(compound_indices, kind='stable') + else: + # Quicksort + sort_indices = np.argsort(compound_indices) + # We must sort size_per_atom accordingly (Issue #3352). + size_per_atom = size_per_atom[sort_indices] + + compound_masks = [] + atom_masks = [] + for compound_size in unique_compound_sizes: + compound_masks.append(compound_sizes == compound_size) + if needs_sorting: + atom_masks.append(sort_indices[size_per_atom + == compound_size] + .reshape(-1, compound_size)) + else: + atom_masks.append(np.where(size_per_atom + == compound_size)[0] + .reshape(-1, compound_size)) + + self._cache[cache_key] = { + "compound_indices": compound_indices, + "data": (atom_masks, compound_masks, len(compound_sizes)) + } - return atom_masks, compound_masks, len(compound_sizes) + return self._cache[cache_key]["data"] @warn_if_not_unique @_pbc_to_wrap diff --git a/testsuite/MDAnalysisTests/core/test_accumulate.py b/testsuite/MDAnalysisTests/core/test_accumulate.py index aadbeee6454..52398e08b59 100644 --- a/testsuite/MDAnalysisTests/core/test_accumulate.py +++ b/testsuite/MDAnalysisTests/core/test_accumulate.py @@ -21,7 +21,7 @@ # J. Comput. Chem. 32 (2011), 2319--2327, doi:10.1002/jcc.21787 # import numpy as np -from numpy.testing import assert_equal, assert_almost_equal +from numpy.testing import assert_equal, assert_almost_equal, assert_allclose import MDAnalysis as mda from MDAnalysis.exceptions import DuplicateWarning, NoDataError @@ -99,7 +99,6 @@ def test_accumulate_array_attribute_compounds(self, name, compound, level): ref = [np.ones((len(a), 2, 5)).sum(axis=0) for a in group.atoms.groupby(name).values()] assert_equal(group.accumulate(np.ones((len(group.atoms), 2, 5)), compound=compound), ref) - class TestTotals(object): """Tests the functionality of *Group.total*() like total_mass and total_charge. @@ -291,3 +290,39 @@ def test_quadrupole_moment_fragments(self, group): assert_almost_equal(quadrupoles, np.array([0., 0.0011629, 0.1182701, 0.6891748 ])) and len(quadrupoles) == n_compounds + + +class TestCache: + @pytest.fixture() + def group(self): + return mda.Universe(PSF, DCD).atoms + + def test_cache(self, group): + """Test that one cache per compound is created.""" + group_nocache = group.copy() + group_cache = group.copy() + + for compound in ['residues', 'fragments']: + actual = group_nocache.accumulate("masses", compound=compound) + desired = group_cache.accumulate("masses", compound=compound) + + assert_allclose(actual, desired) + group_nocache._cache.pop(f"{compound}_masks") + + @pytest.mark.parametrize("compound", + ['residues', 'fragments']) + def test_cache_updating(self, group, compound): + """Test caching of compound_masks for updating atomgroups.""" + kwargs = {"attribute": "masses", "compound": compound} + + group_nocache = group.select_atoms("prop z < 1.0", updating=True) + group_cache = group.select_atoms("prop z < 1.0", updating=True) + + assert_allclose(group_nocache.accumulate(**kwargs), + group_cache.accumulate(**kwargs)) + + # Clear cache and forward to next frame + group_nocache._cache.pop(f"{compound}_masks") + group.universe.trajectory.next() + assert_allclose(group_nocache.accumulate(**kwargs), + group_cache.accumulate(**kwargs))