Skip to content

Commit

Permalink
Merge testing + Add imports
Browse files: browse the repository at this point in the history
  • Loading branch information
FNTwin committed Apr 8, 2024
2 parents 4bec82d + ed73e7d commit a5ced0a
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 17 deletions.
21 changes: 20 additions & 1 deletion openqdc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def get_project_root():
_lazy_imports_obj = {
"__version__": "openqdc._version",
"BaseDataset": "openqdc.datasets.base",
# POTENTIAL
"ANI1": "openqdc.datasets.potential.ani",
"ANI1CCX": "openqdc.datasets.potential.ani",
"ANI1X": "openqdc.datasets.potential.ani",
Expand All @@ -32,12 +33,23 @@ def get_project_root():
"SolvatedPeptides": "openqdc.datasets.potential.solvated_peptides",
"WaterClusters": "openqdc.datasets.potential.waterclusters3_30",
"TMQM": "openqdc.datasets.potential.tmqm",
"Dummy": "openqdc.datasets.potential.dummy",
"PCQM_B3LYP": "openqdc.datasets.potential.pcqm",
"PCQM_PM6": "openqdc.datasets.potential.pcqm",
"RevMD17": "openqdc.datasets.potential.revmd17",
"Transition1X": "openqdc.datasets.potential.transition1x",
"MultixcQM9": "openqdc.datasets.potential.multixcqm9",
# INTERACTION
"DES5M": "openqdc.datasets.interaction.des",
"DES370K": "openqdc.datasets.interaction.des",
"DESS66": "openqdc.datasets.interaction.des",
"DESS66x8": "openqdc.datasets.interaction.des",
"L7": "openqdc.datasets.interaction.l7",
"X40": "openqdc.datasets.interaction.x40",
"Metcalf": "openqdc.datasets.interaction.metcalf",
"Splinter": "openqdc.datasets.interaction.splinter",
# DEBUG
"Dummy": "openqdc.datasets.potential.dummy",
# ALL
"AVAILABLE_DATASETS": "openqdc.datasets",
"AVAILABLE_POTENTIAL_DATASETS": "openqdc.datasets.potential",
"AVAILABLE_INTERACTION_DATASETS": "openqdc.datasets.interaction",
Expand Down Expand Up @@ -75,6 +87,13 @@ def __dir__():
from ._version import __version__ # noqa
from .datasets import AVAILABLE_DATASETS # noqa
from .datasets.base import BaseDataset # noqa

# INTERACTION
from .datasets.interaction.des import DES5M, DES370K, DESS66, DESS66x8 # noqa
from .datasets.interaction.l7 import L7 # noqa
from .datasets.interaction.metcalf import Metcalf # noqa
from .datasets.interaction.splinter import Splinter # noqa
from .datasets.interaction.x40 import X40 # noqa
from .datasets.potential.ani import ANI1, ANI1CCX, ANI1X # noqa
from .datasets.potential.comp6 import COMP6 # noqa
from .datasets.potential.dummy import Dummy # noqa
Expand Down
3 changes: 2 additions & 1 deletion openqdc/datasets/interaction/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from .base import BaseInteractionDataset # noqa
from .des import DES5M, DES370K, DESS66, DESS66x8
from .l7x40 import L7, X40
from .l7 import L7
from .metcalf import Metcalf
from .splinter import Splinter
from .x40 import X40

AVAILABLE_INTERACTION_DATASETS = {
"des5m": DES5M,
Expand Down
3 changes: 0 additions & 3 deletions openqdc/datasets/interaction/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,13 @@ def __getitem__(self, idx: int):
forces = self._convert_array(np.array(self.data["forces"][p_start:p_end], dtype=np.float32))

e0 = self._convert_array(np.array(self.__isolated_atom_energies__[..., z, c + shift].T, dtype=np.float32))
formation_energies = energies - e0.sum(axis=0)

bunch = Bunch(
positions=positions,
atomic_numbers=z,
charges=c,
e0=e0,
energies=energies,
formation_energies=formation_energies,
per_atom_formation_energies=formation_energies / len(z),
name=name,
subset=subset,
forces=forces,
Expand Down
5 changes: 1 addition & 4 deletions openqdc/datasets/interaction/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,13 @@ class DummyInteraction(BaseInteractionDataset):

__name__ = "dummy_interaction"
__energy_methods__ = [InteractionMethod.SAPT0_AUG_CC_PVDDZ, InteractionMethod.CCSD_T_CC_PVDZ]
__force_mask__ = [False, True]
__force_mask__ = [False, False]
__energy_unit__ = "kcal/mol"
__distance_unit__ = "ang"
__forces_unit__ = "kcal/mol/ang"

energy_target_names = [f"energy{i}" for i in range(len(__energy_methods__))]

force_target_names = [f"forces{i}" for i in range(len(__force_mask__))]
__isolated_atom_energies__ = []
__average_n_atoms__ = None

Expand Down Expand Up @@ -48,7 +47,6 @@ def setup_dummy(self):
name = [f"dummy_{i}" for i in range(len(self))]
subset = ["dummy" for i in range(len(self))]
energies = np.random.rand(len(self), len(self.energy_methods))
forces = np.concatenate([np.random.randn(size, 3, len(self.force_methods)) * 100 for size in n_atoms])
self.data = dict(
n_atoms=n_atoms,
position_idx_range=position_idx_range,
Expand All @@ -57,7 +55,6 @@ def setup_dummy(self):
subset=subset,
energies=energies,
n_atoms_first=n_atoms_first,
forces=forces,
)
self.__average_nb_atoms__ = self.data["n_atoms"].mean()

Expand Down
3 changes: 2 additions & 1 deletion openqdc/datasets/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ def to_dict(self):

def transform(self, func):
for k, v in self.to_dict().items():
setattr(self, k, func(v))
if v is not None:
setattr(self, k, func(v))


@dataclass
Expand Down
21 changes: 14 additions & 7 deletions tests/test_dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from openqdc.utils.package_utils import has_package


# start by removing any cached data
@pytest.fixture(autouse=True)
def clean_before_run():
# start by removing any cached data
Expand Down Expand Up @@ -66,12 +67,15 @@ def test_dummy_array_format(interaction_ds, format):
"energies",
"forces",
"e0",
"formation_energies",
"per_atom_formation_energies",
]
if not interaction_ds:
# additional keys returned from the potential dataset
keys.extend(["formation_energies", "per_atom_formation_energies"])

data = ds[0]
for key in keys:
if data[key] is None:
continue
assert isinstance(data[key], format_to_type[format])


Expand Down Expand Up @@ -129,11 +133,12 @@ def test_force_statistics_shapes(ds, request):
keys = ["mean", "std", "component_mean", "component_std", "component_rms"]
assert all(k in forces_stats for k in keys)

assert forces_stats["mean"].shape == (1, num_force_methods)
assert forces_stats["std"].shape == (1, num_force_methods)
assert forces_stats["component_mean"].shape == (3, num_force_methods)
assert forces_stats["component_std"].shape == (3, num_force_methods)
assert forces_stats["component_rms"].shape == (3, num_force_methods)
if len(ds.force_methods) > 0:
assert forces_stats["mean"].shape == (1, num_force_methods)
assert forces_stats["std"].shape == (1, num_force_methods)
assert forces_stats["component_mean"].shape == (3, num_force_methods)
assert forces_stats["component_std"].shape == (3, num_force_methods)
assert forces_stats["component_rms"].shape == (3, num_force_methods)


@pytest.mark.parametrize("interaction_ds", [False, True])
Expand All @@ -147,4 +152,6 @@ def test_stats_array_format(interaction_ds, format):

for key in stats.keys():
for k, v in stats[key].items():
if v is None:
continue
assert isinstance(v, format_to_type[format])

0 comments on commit a5ced0a

Please sign in to comment.