Skip to content

Commit

Permalink
Updated tests for interaction datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
Nikhil Shenoy committed Apr 6, 2024
1 parent 18bc79c commit 2a6e3ef
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 12 deletions.
3 changes: 3 additions & 0 deletions openqdc/datasets/interaction/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,16 @@ def __getitem__(self, idx: int):
forces = self._convert_array(np.array(self.data["forces"][p_start:p_end], dtype=np.float32))

e0 = self._convert_array(np.array(self.__isolated_atom_energies__[..., z, c + shift].T, dtype=np.float32))
formation_energies = energies - e0.sum(axis=0)

bunch = Bunch(
positions=positions,
atomic_numbers=z,
charges=c,
e0=e0,
energies=energies,
formation_energies=formation_energies,
per_atom_formation_energies=formation_energies / len(z),
name=name,
subset=subset,
forces=forces,
Expand Down
2 changes: 1 addition & 1 deletion openqdc/datasets/interaction/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class DummyInteraction(BaseInteractionDataset):
Dummy Interaction Dataset for Testing
"""

__name__ = "dummy"
__name__ = "dummy_interaction"
__energy_methods__ = [InteractionMethod.SAPT0_AUG_CC_PVDDZ, InteractionMethod.CCSD_T_CC_PVDZ]
__force_mask__ = [False, True]
__energy_unit__ = "kcal/mol"
Expand Down
40 changes: 29 additions & 11 deletions tests/test_dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
import numpy as np
import pytest

from openqdc.datasets.interaction.dummy import DummyInteraction # noqa: E402
from openqdc.datasets.potential.dummy import Dummy # noqa: E402
from openqdc.utils.io import get_local_cache
from openqdc.utils.package_utils import has_package

# start by removing any cached data
cache_dir = get_local_cache()
os.system(f"rm -rf {cache_dir}/dummy")
os.system(f"rm -rf {cache_dir}/dummy_interaction")


if has_package("torch"):
Expand All @@ -28,22 +30,30 @@


@pytest.fixture
def ds():
def dummy():
return Dummy()


def test_dummy(ds):
@pytest.fixture
def dummy_interaction():
return DummyInteraction()


@pytest.mark.parametrize("ds", ["dummy", "dummy_interaction"])
def test_dummy(ds, request):
ds = request.getfixturevalue(ds)
assert ds is not None
assert len(ds) == 9999
assert ds[100]


@pytest.mark.parametrize("interaction_ds", [False, True])
@pytest.mark.parametrize("format", ["numpy", "torch", "jax"])
def test_array_format(format):
def test_dummy_array_format(interaction_ds, format):
if not has_package(format):
pytest.skip(f"{format} is not installed, skipping test")

ds = Dummy(array_format=format)
ds = DummyInteraction(array_format=format) if interaction_ds else Dummy(array_format=format)

keys = [
"positions",
Expand All @@ -61,28 +71,33 @@ def test_array_format(format):
assert isinstance(data[key], format_to_type[format])


def test_transform():
@pytest.mark.parametrize("interaction_ds", [False, True])
def test_transform(interaction_ds):
def custom_fn(bunch):
# create new name
bunch.new_key = bunch.name + bunch.subset
return bunch

ds = Dummy(transform=custom_fn)
ds = DummyInteraction(transform=custom_fn) if interaction_ds else Dummy(transform=custom_fn)

data = ds[0]

assert "new_key" in data
assert data["new_key"] == data["name"] + data["subset"]


def test_get_statistics(ds):
@pytest.mark.parametrize("ds", ["dummy", "dummy_interaction"])
def test_get_statistics(ds, request):
ds = request.getfixturevalue(ds)
stats = ds.get_statistics()

keys = ["ForcesCalculatorStats", "FormationEnergyStats", "PerAtomFormationEnergyStats", "TotalEnergyStats"]
assert all(k in stats for k in keys)


def test_energy_statistics_shapes(ds):
@pytest.mark.parametrize("ds", ["dummy", "dummy_interaction"])
def test_energy_statistics_shapes(ds, request):
ds = request.getfixturevalue(ds)
stats = ds.get_statistics()

num_methods = len(ds.energy_methods)
Expand All @@ -100,7 +115,9 @@ def test_energy_statistics_shapes(ds):
assert total_energy_stats["std"].shape == (1, num_methods)


def test_force_statistics_shapes(ds):
@pytest.mark.parametrize("ds", ["dummy", "dummy_interaction"])
def test_force_statistics_shapes(ds, request):
ds = request.getfixturevalue(ds)
stats = ds.get_statistics()
num_force_methods = len(ds.force_methods)

Expand All @@ -115,12 +132,13 @@ def test_force_statistics_shapes(ds):
assert forces_stats["component_rms"].shape == (3, num_force_methods)


@pytest.mark.parametrize("interaction_ds", [False, True])
@pytest.mark.parametrize("format", ["numpy", "torch", "jax"])
def test_stats_array_format(format):
def test_stats_array_format(interaction_ds, format):
if not has_package(format):
pytest.skip(f"{format} is not installed, skipping test")

ds = Dummy(array_format=format)
ds = DummyInteraction(array_format=format) if interaction_ds else Dummy(array_format=format)
stats = ds.get_statistics()

for key in stats.keys():
Expand Down

0 comments on commit 2a6e3ef

Please sign in to comment.