Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Isolated at en #18

Merged
merged 22 commits into from
Nov 2, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/openqdc/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
"WaterClusters": "openqdc.datasets.waterclusters3_30",
"TMQM": "openqdc.datasets.tmqm",
"Dummy": "openqdc.datasets.dummy",
"PCQM_B3LYP": "openqdc.datasets.pcqm",
"PCQM_PM6": "openqdc.datasets.pcqm",
"Transition1X": "openqdc.datasets.transition1x",
}

_lazy_imports_mod = {}
Expand Down Expand Up @@ -68,12 +71,14 @@ def __dir__():
from .molecule3d import Molecule3D # noqa
from .nabladft import NablaDFT # noqa
from .orbnet_denali import OrbnetDenali # noqa
from .pcqm import PCQM_B3LYP, PCQM_PM6 # noqa
from .qm7x import QM7X # noqa
from .qmugs import QMugs # noqa
from .sn2_rxn import SN2RXN # noqa
from .solvated_peptides import SolvatedPeptides # noqa
from .spice import Spice # noqa
from .tmqm import TMQM # noqa
from .transition1x import Transition1X # noqa
from .waterclusters3_30 import WaterClusters # noqa

__all__ = [
Expand All @@ -95,5 +100,8 @@ def __dir__():
"SolvatedPeptides",
"WaterClusters",
"TMQM",
"PCQM_B3LYP",
"PCQM_PM6",
"Transition1X",
"Dummy",
]
20 changes: 0 additions & 20 deletions src/openqdc/datasets/ani.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,23 +145,3 @@ class ANI1X(ANI1):

def convert_forces(self, x):
    """Convert forces with the parent implementation, then rescale them.

    The parent class result is multiplied by the fixed factor 0.529177249,
    which compensates for an error in the forces stored by the dataset.
    """
    converted = super().convert_forces(x)
    return converted * 0.529177249  # correct the Dataset error


# Manual smoke test, executed only when this module is run as a script.
# Each listed dataset class is instantiated, three distinct random indices
# are drawn (replace=False), and for every drawn sample its name, subset and
# the shape of each non-None field are printed on one line.
# NOTE(review): indentation is not visible in this view, so the nesting of
# the trailing print() and exit() calls cannot be confirmed from here.
if __name__ == "__main__":
for data_class in [
ANI1,
# ANI1CCX,
# ANI1X
]:
data = data_class()
n = len(data)

# sample three different indices from the dataset
for i in np.random.choice(n, 3, replace=False):
x = data[i]
print(x.name, x.subset, end=" ")
for k in x:
# fields set to None are skipped
if x[k] is not None:
print(k, x[k].shape, end=" ")

print()
exit()
37 changes: 22 additions & 15 deletions src/openqdc/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,21 +81,27 @@ class BaseDataset(torch.utils.data.Dataset):
__fn_forces__ = lambda x: x

def __init__(
self, energy_unit: Optional[str] = None, distance_unit: Optional[str] = None, cache_dir: Optional[str] = None
self,
energy_unit: Optional[str] = None,
distance_unit: Optional[str] = None,
overwrite_local_cache: bool = False,
cache_dir: Optional[str] = None,
prtos marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
set_cache_dir(cache_dir)
self.data = None
self._set_units(energy_unit, distance_unit)
if not self.is_preprocessed():
entries = self.read_raw_entries()
res = self.collate_list(entries)
self.save_preprocess(res)
self.read_preprocess()
self.__isolated_atom_energies__ = (
[IsolatedAtomEnergyFactory.get(en_method) for en_method in self.__energy_methods__]
if self.__energy_methods__
else None
)
logger.info("This dataset not available. Please open an issue on Github for the team to look into it.")
# entries = self.read_raw_entries()
# res = self.collate_list(entries)
# self.save_preprocess(res)
prtos marked this conversation as resolved.
Show resolved Hide resolved
else:
self.read_preprocess(overwrite_local_cache=overwrite_local_cache)
self.__isolated_atom_energies__ = (
[IsolatedAtomEnergyFactory.get(en_method) for en_method in self.__energy_methods__]
if self.__energy_methods__
else None
)

@property
def energy_unit(self):
Expand Down Expand Up @@ -189,11 +195,6 @@ def collate_list(self, list_entries):
def save_preprocess(self, data_dict):
# save memmaps
logger.info("Preprocessing data and saving it to cache.")
logger.info(
f"Dataset {self.__name__} data with the following units:\n"
f"Energy: {self.energy_unit}, Distance: {self.distance_unit}, "
f"Forces: {self.force_unit if self.__force_methods__ else 'None'}"
)
for key in self.data_keys:
local_path = p_join(self.preprocess_path, f"{key}.mmap")
out = np.memmap(local_path, mode="w+", dtype=data_dict[key].dtype, shape=data_dict[key].shape)
Expand Down Expand Up @@ -245,6 +246,12 @@ def is_preprocessed(self):
predicats += [copy_exists(p_join(self.preprocess_path, f"{x}.npz")) for x in ["name", "subset"]]
return all(predicats)

def preprocess(self):
    """Build and save the preprocessed data unless it already exists.

    When ``is_preprocessed()`` reports True nothing is done. Otherwise the
    raw entries are read, collated and handed to ``save_preprocess``.
    """
    if self.is_preprocessed():
        return
    raw_entries = self.read_raw_entries()
    collated = self.collate_list(raw_entries)
    self.save_preprocess(collated)

def save_xyz(self, idx: int, path: Optional[str] = None):
prtos marked this conversation as resolved.
Show resolved Hide resolved
if path is None:
path = os.getcwd()
Expand Down
2 changes: 0 additions & 2 deletions src/openqdc/datasets/comp6.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ class COMP6(BaseDataset):

__name__ = "comp6"

# Energy in hartree, all zeros by default
atomic_energies = np.zeros((MAX_ATOMIC_NUMBER,), dtype=np.float32)
# watchout that forces are stored as -grad(E)
__energy_unit__ = "kcal/mol"
__distance_unit__ = "bohr" # bohr
Expand Down
1 change: 1 addition & 0 deletions src/openqdc/datasets/orbnet_denali.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pandas as pd

from openqdc.datasets.base import BaseDataset
from openqdc.utils.constants import MAX_ATOMIC_NUMBER
from openqdc.utils.molecule import atom_table


Expand Down
53 changes: 30 additions & 23 deletions src/openqdc/datasets/pcqm.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,36 +56,43 @@ def read_archive(path):
return res


class PubchemQC(BaseDataset):
__name__ = "pubchemqc"
__energy_methods__ = [
"b3lyp/6-31g*",
"pm6",
]

__energy_unit__ = "ev"
__distance_unit__ = "ang"
__forces_unit__ = "ev/ang"

energy_target_names = [
"b3lyp",
"pm6",
]
def read_preprocessed_archive(path):
    """Load a pickled archive from ``path``.

    Returns an empty list when ``path`` does not exist; otherwise returns
    whatever object was pickled into the file.
    """
    if not os.path.exists(path):
        return []
    # NOTE: unpickling can execute arbitrary code; only read trusted archives.
    with open(path, "rb") as handle:
        return pkl.load(handle)


class PCQM_PM6(BaseDataset):
__name__ = "pubchemqc_pm6"
__energy_methods__ = ["pm6"]

energy_target_names = ["pm6"]

def _read_raw_(self, part):
    """Read every ``*.tar.gz`` archive of one dataset part into a flat list.

    Parameters
    ----------
    part
        Name of the sub-directory of ``self.root`` whose archives are read.

    Returns
    -------
    list
        The entries of all archives, concatenated in archive order. Each
        per-archive result of ``read_archive`` is expected to be a list
        (they were previously joined with ``sum(samples, [])``).
    """
    arxiv_paths = glob(p_join(self.root, f"{part}", "*.tar.gz"))
    samples = dm.parallelized(read_archive, arxiv_paths, n_jobs=-1, progress=True, scheduler="threads")
    # Flatten in a single pass; the former debug prints and the exit() call
    # that terminated the interpreter before the return have been removed.
    return [entry for sample in samples for entry in sample]
__force_methods__ = []
force_target_names = []

def __init__(self, energy_unit=None, distance_unit=None, **kwargs) -> None:
    """Create the dataset, delegating all setup to the base class.

    Parameters
    ----------
    energy_unit
        Energy unit forwarded to the base-class constructor.
    distance_unit
        Distance unit forwarded to the base-class constructor.
    **kwargs
        Any further keyword arguments accepted by the base-class
        constructor; they are forwarded unchanged so this override does
        not hide base-class options.
    """
    super().__init__(energy_unit=energy_unit, distance_unit=distance_unit, **kwargs)

@property
def root(self):
    """Path of the ``pubchemqc`` directory inside the local cache."""
    cache_root = get_local_cache()
    return p_join(cache_root, "pubchemqc")

def collate_list(self, list_entries, partial=False):
    """Collate entries, either per chunk (``partial``) or across chunks.

    Parameters
    ----------
    list_entries
        With ``partial=True``: raw entries of one chunk, possibly containing
        ``None`` placeholders (or ``None`` itself). With ``partial=False``:
        a non-empty list of already collated dicts of arrays, each holding a
        ``"position_idx_range"`` array.
    partial
        Selects the mode. The default ``False`` is necessary for
        compatibility with the base class.

    Returns
    -------
    dict or None
        ``partial=True``: the base-class collation of the non-None entries,
        or ``None`` when there are none. ``partial=False``: one dict whose
        arrays are the per-key concatenation (axis 0) of all entries.
    """
    if partial:
        # Drop placeholders first so that None input or an all-None list
        # yields None instead of failing or collating an empty list.
        valid_entries = [x for x in (list_entries or []) if x is not None]
        return super().collate_list(valid_entries) if valid_entries else None

    # Shift each entry's ranges so they continue where the previous entry
    # ended (assumes every entry's ranges start at 0 -- confirm with caller).
    # The shifted maximum already includes the earlier offset, so it becomes
    # the next offset as-is rather than being added to the running total.
    offset = 0
    shifted_ranges = []
    for entry in list_entries:
        shifted = entry["position_idx_range"] + offset  # copy; input untouched
        offset = shifted.max()
        shifted_ranges.append(shifted)

    res = {}
    for key in list_entries[0]:
        if key == "position_idx_range":
            parts = shifted_ranges
        else:
            parts = [entry[key] for entry in list_entries]
        res[key] = np.concatenate(parts, axis=0)
    return res

def read_raw_entries(self):
arxiv_paths = glob(p_join(self.root, f"{self.__energy_methods__[0]}", "*.pkl"))
Expand Down
1 change: 1 addition & 0 deletions src/openqdc/datasets/qm7x.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from tqdm import tqdm

from openqdc.datasets.base import BaseDataset
from openqdc.utils.constants import MAX_ATOMIC_NUMBER
from openqdc.utils.io import load_hdf5_file


Expand Down
3 changes: 2 additions & 1 deletion src/openqdc/datasets/qmugs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np

from openqdc.datasets.base import BaseDataset
from openqdc.utils.constants import MAX_ATOMIC_NUMBER
from openqdc.utils.molecule import get_atomic_number_and_charge


Expand Down Expand Up @@ -52,7 +53,7 @@ class QMugs(BaseDataset):
"""

__name__ = "qmugs"
__energy_methods__ = ["gfn2_xtb", "wb97x-d-D/def2-svp"]
__energy_methods__ = ["gfn2_xtb", "wb97x-d/def2-svp"]
__energy_unit__ = "hartree"
__distance_unit__ = "ang"
__forces_unit__ = "hartree/ang"
Expand Down
4 changes: 3 additions & 1 deletion src/openqdc/datasets/spice.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ def read_record(r):
name=np.array([smiles] * n_confs),
subset=np.array([Spice.subset_mapping[subset]] * n_confs),
energies=r[Spice.energy_target_names[0]][:][:, None].astype(np.float32),
forces=r[Spice.force_target_names[0]][:].reshape(-1, 3, 1) * (-1.0),
forces=r[Spice.force_target_names[0]][:].reshape(
-1, 3, 1
), # forces -ve of energy gradient but the -1.0 is done in the convert_forces method
atomic_inputs=np.concatenate(
(x[None, ...].repeat(n_confs, axis=0), positions), axis=-1, dtype=np.float32
).reshape(-1, 5),
Expand Down
1 change: 1 addition & 0 deletions src/openqdc/datasets/tmqm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from tqdm import tqdm

from openqdc.datasets.base import BaseDataset
from openqdc.utils.constants import MAX_ATOMIC_NUMBER
from openqdc.utils.molecule import atom_table


Expand Down
3 changes: 0 additions & 3 deletions src/openqdc/datasets/transition1x.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,6 @@ class Transition1X(BaseDataset):
"wB97x_6-31G(d).forces",
]

def __init__(self, *args, **kwargs) -> None:
    """Create the dataset by delegating to the base-class constructor.

    All positional and keyword arguments are forwarded unchanged, so the
    options of the base-class constructor remain available; calling with no
    arguments behaves as before.
    """
    super().__init__(*args, **kwargs)

def read_raw_entries(self):
raw_path = p_join(self.root, "Transition1x.h5")
f = load_hdf5_file(raw_path)["data"]
Expand Down
1 change: 1 addition & 0 deletions src/openqdc/datasets/waterclusters3_30.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from tqdm import tqdm

from openqdc.datasets.base import BaseDataset
from openqdc.utils.constants import MAX_ATOMIC_NUMBER
from openqdc.utils.molecule import atom_table

# we could use ase.io.read to read extxyz files
Expand Down
20 changes: 20 additions & 0 deletions src/openqdc/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,23 @@ def dict_to_atoms(d: dict):
pos, atomic_numbers, charges = d.pop("positions"), d.pop("atomic_numbers"), d.pop("charges")
at = Atoms(positions=pos, numbers=atomic_numbers, charges=charges, info=d)
return at


def print_h5_tree(val, pre=""):
    """Print the contents of an HDF5 group as a tree.

    Parameters
    ----------
    val
        Group-like object providing ``len()`` and ``.items()``; nested
        ``h5py.Group`` children are descended into recursively.
    pre
        Prefix printed before every line of this level; used internally to
        draw the indentation of nested levels.

    Each group is printed by name. Every other entry is printed with its
    ``len()`` in parentheses.
    """
    remaining = len(val)
    for key, child in val.items():
        remaining -= 1
        is_last = remaining == 0
        connector = "└── " if is_last else "├── "
        if isinstance(child, h5py.Group):
            print(pre + connector + key)
            # Child prefixes are as wide as the connector (4 characters) so
            # nested names line up under their parent's name.
            print_h5_tree(child, pre + ("    " if is_last else "│   "))
        else:
            print(pre + connector + key + " (%d)" % len(child))
You are viewing a condensed version of this merge commit. You can view the full changes here.