Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Interaction and Better Testing #71

Closed
wants to merge 32 commits into from
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
31beb71
refactor interaction and initial testing
Mar 26, 2024
dccf676
minor changes
Mar 26, 2024
2ab64aa
dummy modification
Mar 26, 2024
189ab90
undo changes in interaction dataset, and minor change in shape
Mar 29, 2024
282dc91
changed super class to BaseInteractionDataset
Apr 2, 2024
701ef1e
Merge branch 'release' into testing
Apr 3, 2024
afea053
further simplified and rebase
Apr 3, 2024
ebc2adf
fixes
Apr 5, 2024
7ffd0b1
Merge remote-tracking branch 'origin/release' into testing
Apr 5, 2024
d15e9cf
Merge remote-tracking branch 'origin/release' into testing
Apr 5, 2024
ed8e264
Updated metcalf
Apr 5, 2024
18bc79c
bug fix and simplifying interaction dataset
Apr 6, 2024
2a6e3ef
Updated tests for interaction datasets
Apr 6, 2024
7493273
removed stale stats in dummy interaction
Apr 6, 2024
ed73e7d
changes based on comments
Apr 6, 2024
0359022
Clean metcalf
FNTwin Apr 6, 2024
33fa342
Simplification
FNTwin Apr 6, 2024
cd486a8
cleaned des
FNTwin Apr 6, 2024
80d7371
Simplified des dataset
FNTwin Apr 6, 2024
f3d205c
removed redundant dataset files
FNTwin Apr 6, 2024
da4fece
DES inerithance
FNTwin Apr 6, 2024
71ff741
Removed des and improved des naming
FNTwin Apr 6, 2024
f6e12e1
DES fixes
FNTwin Apr 6, 2024
3328a65
Removed comments
FNTwin Apr 6, 2024
8b28d59
X40 and L70
FNTwin Apr 6, 2024
8595fd8
Safe opening
FNTwin Apr 6, 2024
ca1b4af
Moved X40 in L7 and removed x40.py
FNTwin Apr 6, 2024
4bec82d
Moved Yaml utils to _utils.py, L7 + X40 interface
FNTwin Apr 7, 2024
a5ced0a
Merge testing + Add imports
FNTwin Apr 8, 2024
a21963e
Merge pull request #79 from OpenDrugDiscovery/interaction_impr
shenoynikhil Apr 8, 2024
3303f95
better convert function and n_body_first to ptr
Apr 12, 2024
6f033cf
Updated splinter reading from -1 to nan
Apr 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions openqdc/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,10 @@ def data_keys(self):
keys.remove("forces")
return keys

@property
shenoynikhil marked this conversation as resolved.
Show resolved Hide resolved
def pkl_data_keys(self):
shenoynikhil marked this conversation as resolved.
Show resolved Hide resolved
return ["name", "subset", "n_atoms"]

@property
def data_types(self):
return {
Expand All @@ -322,8 +326,8 @@ def data_shapes(self):
return {
"atomic_inputs": (-1, NB_ATOMIC_FEATURES),
"position_idx_range": (-1, 2),
"energies": (-1, len(self.energy_target_names)),
"forces": (-1, 3, len(self.force_target_names)),
"energies": (-1, len(self.energy_methods)),
"forces": (-1, 3, len(self.force_methods)),
}

@property
Expand Down Expand Up @@ -420,8 +424,13 @@ def save_preprocess(self, data_dict):

# save smiles and subset
local_path = p_join(self.preprocess_path, "props.pkl")
for key in ["name", "subset"]:
data_dict[key] = np.unique(data_dict[key], return_inverse=True)
# assert that required keys are present in data_dict
assert all([key in data_dict for key in self.pkl_data_keys])
for key in data_dict:
if key not in self.data_keys:
x = data_dict[key]
x[x == None] = -1 # noqa
shenoynikhil marked this conversation as resolved.
Show resolved Hide resolved
data_dict[key] = np.unique(data_dict[key], return_inverse=True)

with open(local_path, "wb") as f:
pkl.dump(data_dict, f)
Expand Down Expand Up @@ -457,7 +466,10 @@ def read_preprocess(self, overwrite_local_cache=False):
pull_locally(filename, overwrite=overwrite_local_cache)
with open(filename, "rb") as f:
tmp = pkl.load(f)
for key in ["name", "subset", "n_atoms"]:
all_pkl_keys = set(tmp.keys()) - set(self.data_keys)
# assert required pkl_keys are present in all_pkl_keys
assert all([key in all_pkl_keys for key in self.pkl_data_keys])
shenoynikhil marked this conversation as resolved.
Show resolved Hide resolved
for key in all_pkl_keys:
x = tmp.pop(key)
if len(x) == 2:
self.data[key] = x[0][x[1]]
Expand Down
78 changes: 6 additions & 72 deletions openqdc/datasets/interaction/base.py
Copy link
Collaborator

@FNTwin FNTwin Apr 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently trying to load any interaction dataset will get you an error due to the:
if not self.is_preprocessed() failing due to the naming.

In the bucket they were written L7 and X40 (upper case). We should always have the sanitize name on lower case. As we need to postprocess it again to have the new keys. It will fix itself

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, need to push new changes.

Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
import os
import pickle as pkl
from os.path import join as p_join
from typing import Dict, List, Optional

import numpy as np
from ase.io.extxyz import write_extxyz
from loguru import logger
from sklearn.utils import Bunch

from openqdc.datasets.base import BaseDataset
from openqdc.utils.constants import MAX_CHARGE, NB_ATOMIC_FEATURES
from openqdc.utils.io import pull_locally, push_remote, to_atoms
from openqdc.utils.constants import MAX_CHARGE
from openqdc.utils.io import to_atoms


class BaseInteractionDataset(BaseDataset):
__energy_type__ = []

@property
shenoynikhil marked this conversation as resolved.
Show resolved Hide resolved
def pkl_data_keys(self):
shenoynikhil marked this conversation as resolved.
Show resolved Hide resolved
return ["name", "subset", "n_atoms", "n_atoms_first"]

def collate_list(self, list_entries: List[Dict]):
# concatenate entries
res = {
Expand All @@ -31,24 +33,6 @@ def collate_list(self, list_entries: List[Dict]):

return res

@property
shenoynikhil marked this conversation as resolved.
Show resolved Hide resolved
def data_shapes(self):
return {
"atomic_inputs": (-1, NB_ATOMIC_FEATURES),
"position_idx_range": (-1, 2),
"energies": (-1, len(self.__energy_methods__)),
"forces": (-1, 3, len(self.force_target_names)),
}

@property
def data_types(self):
return {
"atomic_inputs": np.float32,
"position_idx_range": np.int32,
"energies": np.float32,
"forces": np.float32,
}

def __getitem__(self, idx: int):
shift = MAX_CHARGE
p_start, p_end = self.data["position_idx_range"][idx]
Expand Down Expand Up @@ -79,56 +63,6 @@ def __getitem__(self, idx: int):
n_atoms_first=n_atoms_first,
)

def save_preprocess(self, data_dict):
# save memmaps
logger.info("Preprocessing data and saving it to cache.")
for key in self.data_keys:
local_path = p_join(self.preprocess_path, f"{key}.mmap")
out = np.memmap(local_path, mode="w+", dtype=data_dict[key].dtype, shape=data_dict[key].shape)
out[:] = data_dict.pop(key)[:]
out.flush()
push_remote(local_path, overwrite=True)

# save all other keys in props.pkl
local_path = p_join(self.preprocess_path, "props.pkl")
for key in data_dict:
shenoynikhil marked this conversation as resolved.
Show resolved Hide resolved
if key not in self.data_keys:
x = data_dict[key]
x[x == None] = -1 # noqa
data_dict[key] = np.unique(x, return_inverse=True)

with open(local_path, "wb") as f:
pkl.dump(data_dict, f)
push_remote(local_path, overwrite=True)

def read_preprocess(self, overwrite_local_cache=False):
shenoynikhil marked this conversation as resolved.
Show resolved Hide resolved
logger.info("Reading preprocessed data.")
logger.info(
f"Dataset {self.__name__} with the following units:\n\
Energy: {self.energy_unit},\n\
Distance: {self.distance_unit},\n\
Forces: {self.force_unit if self.__force_methods__ else 'None'}"
)
self.data = {}
for key in self.data_keys:
filename = p_join(self.preprocess_path, f"{key}.mmap")
pull_locally(filename, overwrite=overwrite_local_cache)
self.data[key] = np.memmap(filename, mode="r", dtype=self.data_types[key]).reshape(self.data_shapes[key])

filename = p_join(self.preprocess_path, "props.pkl")
pull_locally(filename, overwrite=overwrite_local_cache)
with open(filename, "rb") as f:
tmp = pkl.load(f)
for key in set(tmp.keys()) - set(self.data_keys):
x = tmp.pop(key)
if len(x) == 2:
self.data[key] = x[0][x[1]]
else:
self.data[key] = x

for key in self.data:
logger.info(f"Loaded {key} with shape {self.data[key].shape}, dtype {self.data[key].dtype}")

def get_ase_atoms(self, idx: int):
entry = self[idx]
at = to_atoms(entry["positions"], entry["atomic_numbers"])
shenoynikhil marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
94 changes: 94 additions & 0 deletions openqdc/datasets/interaction/dummy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import numpy as np

from openqdc.datasets.interaction.base import BaseInteractionDataset
from openqdc.methods import InteractionMethod
from openqdc.utils.constants import NOT_DEFINED


class DummyInteraction(BaseInteractionDataset):
"""
Dummy Interaction Dataset for Testing
"""

__name__ = "dummy"
__energy_methods__ = [InteractionMethod.SAPT0_AUG_CC_PVDDZ, InteractionMethod.CCSD_T_CC_PVDZ]
__force_mask__ = [False, True]
shenoynikhil marked this conversation as resolved.
Show resolved Hide resolved
__energy_unit__ = "kcal/mol"
__distance_unit__ = "ang"
__forces_unit__ = "kcal/mol/ang"

energy_target_names = [f"energy{i}" for i in range(len(__energy_methods__))]

force_target_names = [f"forces{i}" for i in range(len(__force_mask__))]
shenoynikhil marked this conversation as resolved.
Show resolved Hide resolved
__isolated_atom_energies__ = []
__average_n_atoms__ = None

def _post_init(self, overwrite_local_cache, energy_unit, distance_unit) -> None:
self.setup_dummy()
return super()._post_init(overwrite_local_cache, energy_unit, distance_unit)

@property
def _stats(self):
return {
"formation": {
"energy": {
"mean": np.array([[-12.94348027, -9.83037297]]),
"std": np.array([[4.39971409, 3.3574188]]),
},
"forces": NOT_DEFINED,
},
"total": {
"energy": {
"mean": np.array([[-89.44242, -1740.5336]]),
"std": np.array([[29.599571, 791.48663]]),
},
"forces": NOT_DEFINED,
},
}

def setup_dummy(self):
n_atoms = np.array([np.random.randint(10, 30) for _ in range(len(self))])
n_atoms_first = np.array([np.random.randint(1, 10) for _ in range(len(self))])
position_idx_range = np.concatenate([[0], np.cumsum(n_atoms)]).repeat(2)[1:-1].reshape(-1, 2)
atomic_inputs = np.concatenate(
[
np.concatenate(
[
# z, c, x, y, z
np.random.randint(1, 100, size=(size, 1)),
np.random.randint(-1, 2, size=(size, 1)),
np.random.randn(size, 3),
],
axis=1,
)
for size in n_atoms
],
axis=0,
) # (sum(n_atoms), 5)
name = [f"dummy_{i}" for i in range(len(self))]
subset = ["dummy" for i in range(len(self))]
energies = np.random.rand(len(self), len(self.energy_methods))
forces = np.concatenate([np.random.randn(size, 3, len(self.force_methods)) * 100 for size in n_atoms])
self.data = dict(
n_atoms=n_atoms,
position_idx_range=position_idx_range,
name=name,
atomic_inputs=atomic_inputs,
subset=subset,
energies=energies,
n_atoms_first=n_atoms_first,
forces=forces,
)
self.__average_nb_atoms__ = self.data["n_atoms"].mean()

def read_preprocess(self, overwrite_local_cache=False):
return

def is_preprocessed(self):
return True

def read_raw_entries(self):
pass

def __len__(self):
shenoynikhil marked this conversation as resolved.
Show resolved Hide resolved
return 9999
17 changes: 5 additions & 12 deletions openqdc/datasets/potential/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,9 @@ def _stats(self):
},
}

def __init__(
self,
energy_unit=None,
distance_unit=None,
cache_dir=None,
) -> None:
try:
super().__init__(energy_unit=energy_unit, distance_unit=distance_unit, cache_dir=cache_dir)

except: # noqa
pass
self._set_isolated_atom_energies()
def _post_init(self, overwrite_local_cache, energy_unit, distance_unit) -> None:
self.setup_dummy()
return super()._post_init(overwrite_local_cache, energy_unit, distance_unit)

def setup_dummy(self):
n_atoms = np.array([np.random.randint(1, 100) for _ in range(len(self))])
Expand Down Expand Up @@ -89,6 +79,9 @@ def setup_dummy(self):
)
self.__average_nb_atoms__ = self.data["n_atoms"].mean()

def read_preprocess(self, overwrite_local_cache=False):
return

def is_preprocessed(self):
return True

Expand Down
45 changes: 41 additions & 4 deletions tests/test_dummy.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,49 @@
"""Path hack to make tests work."""

import pytest

from openqdc.datasets.interaction.dummy import DummyInteraction # noqa: E402
from openqdc.datasets.potential.dummy import Dummy # noqa: E402


def test_dummy():
ds = Dummy()
assert len(ds) > 10
assert ds[100]
@pytest.fixture
def dummy():
return Dummy()


@pytest.fixture
def dummy_interaction():
return DummyInteraction()


@pytest.mark.parametrize("cls", ["dummy", "dummy_interaction"])
def test_basic(cls, request):
# init
ds = request.getfixturevalue(cls)

# len
assert len(ds) == 9999

# __getitem__
assert ds[0]


@pytest.mark.parametrize("cls", ["dummy", "dummy_interaction"])
@pytest.mark.parametrize(
"normalization",
[
"formation",
"total",
# "residual_regression",
# "per_atom_formation",
# "per_atom_residual_regression"
],
)
def test_stats(cls, normalization, request):
ds = request.getfixturevalue(cls)

stats = ds.get_statistics(normalization=normalization)
assert stats is not None


# def test_is_at_factory():
Expand Down
Loading