diff --git a/pyproject.toml b/pyproject.toml index 0b8bf5c..a6959af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "znh5md" -version = "0.4.1" +version = "0.4.2" description = "ASE Interface for the H5MD format." authors = ["zincwarecode "] license = "Apache-2.0" diff --git a/tests/test_single_obs_key.py b/tests/test_single_obs_key.py new file mode 100644 index 0000000..84a5afe --- /dev/null +++ b/tests/test_single_obs_key.py @@ -0,0 +1,66 @@ +import numpy.testing as npt +from ase.build import molecule +from ase.calculators.singlepoint import SinglePointCalculator + +import znh5md +import znh5md.serialization + + +def test_single_entry_info(tmp_path): + # Test a special case where only the first config has the key + # which caused an error in the past + io = znh5md.IO(tmp_path / "test.h5") + water = molecule("H2O") + water.info["density"] = 0.997 + io.append(water) + del water.info["density"] + io.extend([water for _ in range(5)]) + assert len(io) == 6 + assert len(list(io)) == 6 + assert len(io[:]) == 6 + assert io[0].info["density"] == 0.997 + assert "density" not in io[1].info + + frames = znh5md.serialization.Frames.from_ase(list(io)) + assert len(frames) == 6 + assert len(list(frames)) == 6 + + +def test_single_entry_arrays(tmp_path): + # Test a special case where only the first config has the key + # which caused an error in the past + io = znh5md.IO(tmp_path / "test.h5") + water = molecule("H2O") + water.arrays["density"] = [0.997, 0.998, 0.999] + io.append(water) + del water.arrays["density"] + io.extend([water for _ in range(5)]) + assert len(io) == 6 + assert len(list(io)) == 6 + assert len(io[:]) == 6 + npt.assert_array_equal(io[0].arrays["density"], [0.997, 0.998, 0.999]) + assert "density" not in io[1].arrays + + frames = znh5md.serialization.Frames.from_ase(list(io)) + assert len(frames) == 6 + assert len(list(frames)) == 6 + + +def test_single_entry_calc(tmp_path): + # Test a special case where only the first config has the key + # which caused an error in the past + io = znh5md.IO(tmp_path / "test.h5") + water = molecule("H2O") + water.calc = SinglePointCalculator(water, energy=0.0, forces=[0.0, 0.0, 0.0]) + io.append(water) + water.calc = None + io.extend([water for _ in range(5)]) + assert len(io) == 6 + assert len(list(io)) == 6 + assert len(io[:]) == 6 + assert io[0].calc.results["energy"] == 0.0 + assert io[1].calc is None + + frames = znh5md.serialization.Frames.from_ase(list(io)) + assert len(frames) == 6 + assert len(list(frames)) == 6 diff --git a/tests/test_znh5md.py b/tests/test_znh5md.py index e800b4d..185200d 100644 --- a/tests/test_znh5md.py +++ b/tests/test_znh5md.py @@ -2,7 +2,7 @@ def test_version(): - assert znh5md.__version__ == "0.4.1" + assert znh5md.__version__ == "0.4.2" def test_creator(tmp_path): diff --git a/znh5md/interface/read.py b/znh5md/interface/read.py index ba9eee4..235903f 100644 --- a/znh5md/interface/read.py +++ b/znh5md/interface/read.py @@ -329,17 +329,18 @@ def process_observables(self, frames: Frames, observables, index) -> None: origin = grp.attrs.get(AttributePath.origin.value, None) try: try: - update_frames( - frames, - H5MDToASEMapping[grp_name].value, - grp["value"][index], - origin, - self.use_ase_calc, - ) - except KeyError: - update_frames( - frames, grp_name, grp["value"][index], origin, self.use_ase_calc - ) + try: + update_frames( + frames, + H5MDToASEMapping[grp_name].value, + grp["value"][index], + origin, + self.use_ase_calc, + ) + except KeyError: + update_frames( + frames, grp_name, grp["value"][index], origin, self.use_ase_calc + ) except (OSError, IndexError): pass # Handle backfilling for invalid values except KeyError: diff --git a/znh5md/serialization/base.py b/znh5md/serialization/base.py index ce47cdf..d5e0504 100644 --- a/znh5md/serialization/base.py +++ b/znh5md/serialization/base.py @@ -1,3 +1,4 @@ +import contextlib import dataclasses import functools import json @@ -238,28 +239,33 @@ def __len__(self) -> int: def __getitem__(self, idx: int) -> ase.Atoms: """Return a single frame.""" + # this raises the IndexError to determine the length of the Frames object atoms = ase.Atoms( numbers=self.numbers[idx], positions=self.positions[idx], cell=self.cell[idx], pbc=self.pbc[idx], ) + # all data following here can be missing for key in self.arrays: - if isinstance(self.arrays[key][idx], _MISSING): - continue - if key == "velocities": - atoms.set_velocities(self.arrays[key][idx]) - else: - atoms.arrays[key] = self.arrays[key][idx] + with contextlib.suppress(IndexError): + if isinstance(self.arrays[key][idx], _MISSING): + continue + if key == "velocities": + atoms.set_velocities(self.arrays[key][idx]) + else: + atoms.arrays[key] = self.arrays[key][idx] for key in self.info: - if not isinstance(self.info[key][idx], _MISSING): - atoms.info[key] = self.info[key][idx] + with contextlib.suppress(IndexError): + if not isinstance(self.info[key][idx], _MISSING): + atoms.info[key] = self.info[key][idx] for key in self.calc: - if not isinstance(self.calc[key][idx], _MISSING): - if atoms.calc is None: - atoms.calc = SinglePointCalculator(atoms) - atoms.calc.results[key] = self.calc[key][idx] + with contextlib.suppress(IndexError): + if not isinstance(self.calc[key][idx], _MISSING): + if atoms.calc is None: + atoms.calc = SinglePointCalculator(atoms) + atoms.calc.results[key] = self.calc[key][idx] return atoms