diff --git a/openqdc/datasets/interaction/base.py b/openqdc/datasets/interaction/base.py index 96f39c1..8a8e2ea 100644 --- a/openqdc/datasets/interaction/base.py +++ b/openqdc/datasets/interaction/base.py @@ -42,6 +42,7 @@ def __getitem__(self, idx: int): forces = self._convert_array(np.array(self.data["forces"][p_start:p_end], dtype=np.float32)) e0 = self._convert_array(np.array(self.__isolated_atom_energies__[..., z, c + shift].T, dtype=np.float32)) + formation_energies = energies - e0.sum(axis=0) bunch = Bunch( positions=positions, @@ -49,6 +50,8 @@ def __getitem__(self, idx: int): charges=c, e0=e0, energies=energies, + formation_energies=formation_energies, + per_atom_formation_energies=formation_energies / len(z), name=name, subset=subset, forces=forces, diff --git a/openqdc/datasets/interaction/dummy.py b/openqdc/datasets/interaction/dummy.py index 48e92a9..71bf5ee 100644 --- a/openqdc/datasets/interaction/dummy.py +++ b/openqdc/datasets/interaction/dummy.py @@ -10,7 +10,7 @@ class DummyInteraction(BaseInteractionDataset): Dummy Interaction Dataset for Testing """ - __name__ = "dummy" + __name__ = "dummy_interaction" __energy_methods__ = [InteractionMethod.SAPT0_AUG_CC_PVDDZ, InteractionMethod.CCSD_T_CC_PVDZ] __force_mask__ = [False, True] __energy_unit__ = "kcal/mol" diff --git a/tests/test_dummy.py b/tests/test_dummy.py index 08ee127..a241384 100644 --- a/tests/test_dummy.py +++ b/tests/test_dummy.py @@ -5,6 +5,7 @@ import numpy as np import pytest +from openqdc.datasets.interaction.dummy import DummyInteraction # noqa: E402 from openqdc.datasets.potential.dummy import Dummy # noqa: E402 from openqdc.utils.io import get_local_cache from openqdc.utils.package_utils import has_package @@ -12,6 +13,7 @@ # start by removing any cached data cache_dir = get_local_cache() os.system(f"rm -rf {cache_dir}/dummy") +os.system(f"rm -rf {cache_dir}/dummy_interaction") if has_package("torch"): @@ -28,22 +30,30 @@ @pytest.fixture -def ds(): +def dummy(): return Dummy() -def test_dummy(ds): +@pytest.fixture +def dummy_interaction(): + return DummyInteraction() + + +@pytest.mark.parametrize("ds", ["dummy", "dummy_interaction"]) +def test_dummy(ds, request): + ds = request.getfixturevalue(ds) assert ds is not None assert len(ds) == 9999 assert ds[100] +@pytest.mark.parametrize("interaction_ds", [False, True]) @pytest.mark.parametrize("format", ["numpy", "torch", "jax"]) -def test_array_format(format): +def test_dummy_array_format(interaction_ds, format): if not has_package(format): pytest.skip(f"{format} is not installed, skipping test") - ds = Dummy(array_format=format) + ds = DummyInteraction(array_format=format) if interaction_ds else Dummy(array_format=format) keys = [ "positions", @@ -61,13 +71,14 @@ def test_array_format(format): assert isinstance(data[key], format_to_type[format]) -def test_transform(): +@pytest.mark.parametrize("interaction_ds", [False, True]) +def test_transform(interaction_ds): def custom_fn(bunch): # create new name bunch.new_key = bunch.name + bunch.subset return bunch - ds = Dummy(transform=custom_fn) + ds = DummyInteraction(transform=custom_fn) if interaction_ds else Dummy(transform=custom_fn) data = ds[0] @@ -75,14 +86,18 @@ def custom_fn(bunch): assert data["new_key"] == data["name"] + data["subset"] -def test_get_statistics(ds): +@pytest.mark.parametrize("ds", ["dummy", "dummy_interaction"]) +def test_get_statistics(ds, request): + ds = request.getfixturevalue(ds) stats = ds.get_statistics() keys = ["ForcesCalculatorStats", "FormationEnergyStats", "PerAtomFormationEnergyStats", "TotalEnergyStats"] assert all(k in stats for k in keys) -def test_energy_statistics_shapes(ds): +@pytest.mark.parametrize("ds", ["dummy", "dummy_interaction"]) +def test_energy_statistics_shapes(ds, request): + ds = request.getfixturevalue(ds) stats = ds.get_statistics() num_methods = len(ds.energy_methods) @@ -100,7 +115,9 @@ def test_energy_statistics_shapes(ds): assert total_energy_stats["std"].shape == (1, num_methods) -def test_force_statistics_shapes(ds): +@pytest.mark.parametrize("ds", ["dummy", "dummy_interaction"]) +def test_force_statistics_shapes(ds, request): + ds = request.getfixturevalue(ds) stats = ds.get_statistics() num_force_methods = len(ds.force_methods) @@ -115,12 +132,13 @@ def test_force_statistics_shapes(ds): assert forces_stats["component_rms"].shape == (3, num_force_methods) +@pytest.mark.parametrize("interaction_ds", [False, True]) @pytest.mark.parametrize("format", ["numpy", "torch", "jax"]) -def test_stats_array_format(format): +def test_stats_array_format(interaction_ds, format): if not has_package(format): pytest.skip(f"{format} is not installed, skipping test") - ds = Dummy(array_format=format) + ds = DummyInteraction(array_format=format) if interaction_ds else Dummy(array_format=format) stats = ds.get_statistics() for key in stats.keys():