Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

unique enums + initial structure for api endpoint #86

Merged
merged 16 commits into from
Jun 8, 2024
Merged
3 changes: 2 additions & 1 deletion openqdc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def get_project_root():
"ANI1CCX": "openqdc.datasets.potential.ani",
"ANI1CCX_V2": "openqdc.datasets.potential.ani",
"ANI1X": "openqdc.datasets.potential.ani",
"ANI2": "openqdc.datasets.potential.ani",
"Spice": "openqdc.datasets.potential.spice",
"SpiceV2": "openqdc.datasets.potential.spice",
"SpiceVL2": "openqdc.datasets.potential.spice",
Expand Down Expand Up @@ -100,7 +101,7 @@ def __dir__():
from .datasets.interaction.metcalf import Metcalf
from .datasets.interaction.splinter import Splinter
from .datasets.interaction.x40 import X40
from .datasets.potential.ani import ANI1, ANI1CCX, ANI1CCX_V2, ANI1X
from .datasets.potential.ani import ANI1, ANI1CCX, ANI1CCX_V2, ANI1X, ANI2
from .datasets.potential.comp6 import COMP6
from .datasets.potential.dummy import Dummy
from .datasets.potential.gdml import GDML
Expand Down
26 changes: 16 additions & 10 deletions openqdc/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,15 @@
app = typer.Typer(help="OpenQDC CLI")


# Translation table that deletes the separator characters ("_" and "-")
# allowed in dataset names; built once so each key is cleaned in one pass.
_SEPARATOR_TABLE = str.maketrans("", "", "_-")


def sanitize(dictionary):
    """Return a new dict whose keys are normalized versions of the input keys.

    Each key is lower-cased and has every ``_`` and ``-`` removed, so that
    e.g. ``"ANI1CCX_V2"`` becomes ``"ani1ccxv2"``. Values are passed through
    untouched and the input mapping is not modified.

    Parameters
    ----------
    dictionary
        Mapping with string keys.

    Returns
    -------
    dict
        Mapping from normalized key to the original value. If two keys
        normalize to the same string, the entry that appears later in
        ``dictionary`` wins.
    """
    return {key.lower().translate(_SEPARATOR_TABLE): value for key, value in dictionary.items()}


SANITIZED_AVAILABLE_DATASETS = sanitize(AVAILABLE_DATASETS)


def exist_dataset(dataset):
    """Tell whether ``dataset`` is a known (sanitized) dataset name.

    The lookup is done against ``SANITIZED_AVAILABLE_DATASETS``, the same
    precomputed mapping the CLI commands index after this check, so a
    ``True`` result guarantees that the subsequent lookup succeeds.
    The name is compared as given; callers are expected to have normalized
    it already.

    Parameters
    ----------
    dataset
        Normalized dataset name to look up.

    Returns
    -------
    bool
        ``True`` if the dataset is available, ``False`` otherwise. An error
        is logged in the ``False`` case.
    """
    if dataset in SANITIZED_AVAILABLE_DATASETS:
        return True
    logger.error(f"{dataset} is not available. Please open an issue on Github for the team to look into it.")
    return False
Expand Down Expand Up @@ -57,10 +64,10 @@ def download(
"""
for dataset in list(map(lambda x: x.lower().replace("_", ""), datasets)):
if exist_dataset(dataset):
if AVAILABLE_DATASETS[dataset].no_init().is_cached() and not overwrite:
if SANITIZED_AVAILABLE_DATASETS[dataset].no_init().is_cached() and not overwrite:
logger.info(f"{dataset} is already cached. Skipping download")
else:
AVAILABLE_DATASETS[dataset](overwrite_local_cache=True, cache_dir=cache_dir)
SANITIZED_AVAILABLE_DATASETS[dataset](overwrite_local_cache=True, cache_dir=cache_dir)


@app.command()
Expand Down Expand Up @@ -115,18 +122,17 @@ def fetch(
openqdc fetch Spice
"""
if datasets[0].lower() == "all":
dataset_names = AVAILABLE_DATASETS
dataset_names = list(sanitize(AVAILABLE_DATASETS).keys())
elif datasets[0].lower() == "potential":
dataset_names = AVAILABLE_POTENTIAL_DATASETS
dataset_names = list(sanitize(AVAILABLE_POTENTIAL_DATASETS).keys())
elif datasets[0].lower() == "interaction":
dataset_names = AVAILABLE_INTERACTION_DATASETS
dataset_names = list(sanitize(AVAILABLE_INTERACTION_DATASETS).keys())
else:
dataset_names = datasets

for dataset in list(map(lambda x: x.lower().replace("_", ""), dataset_names)):
if exist_dataset(dataset):
try:
AVAILABLE_DATASETS[dataset].fetch(cache_dir, overwrite)
SANITIZED_AVAILABLE_DATASETS[dataset].fetch(cache_dir, overwrite)
except Exception as e:
logger.error(f"Something unexpected happended while fetching {dataset}: {repr(e)}")

Expand All @@ -152,9 +158,9 @@ def preprocess(
"""
for dataset in list(map(lambda x: x.lower().replace("_", ""), datasets)):
if exist_dataset(dataset):
logger.info(f"Preprocessing {AVAILABLE_DATASETS[dataset].__name__}")
logger.info(f"Preprocessing {SANITIZED_AVAILABLE_DATASETS[dataset].__name__}")
try:
AVAILABLE_DATASETS[dataset].no_init().preprocess(upload=upload, overwrite=overwrite)
SANITIZED_AVAILABLE_DATASETS[dataset].no_init().preprocess(upload=upload, overwrite=overwrite)
except Exception as e:
logger.error(f"Error while preprocessing {dataset}. {e}. Did you fetch the dataset first?")
raise e
Expand Down
6 changes: 3 additions & 3 deletions openqdc/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __init__(
energy_unit: Optional[str] = None,
distance_unit: Optional[str] = None,
array_format: str = "numpy",
energy_type: str = "formation",
energy_type: Optional[str] = "formation",
overwrite_local_cache: bool = False,
cache_dir: Optional[str] = None,
recompute_statistics: bool = False,
Expand All @@ -112,7 +112,7 @@ def __init__(
Format to return arrays in. Supported formats: ["numpy", "torch", "jax"]
energy_type
Type of isolated atom energy to use for the dataset. Default: "formation"
Supported types: ["formation", "regression", "null"]
Supported types: ["formation", "regression", "null", None]
overwrite_local_cache
Whether to overwrite the locally cached dataset.
cache_dir
Expand All @@ -133,7 +133,7 @@ def __init__(
self.recompute_statistics = recompute_statistics
self.regressor_kwargs = regressor_kwargs
self.transform = transform
self.energy_type = energy_type
self.energy_type = energy_type if energy_type is not None else "null"
self.refit_e0s = recompute_statistics or overwrite_local_cache
if not self.is_preprocessed():
raise DatasetNotAvailableError(self.__name__)
Expand Down
3 changes: 1 addition & 2 deletions openqdc/datasets/energies.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@
from loguru import logger

from openqdc.methods.enums import PotentialMethod
from openqdc.utils.constants import ATOM_SYMBOLS, ATOMIC_NUMBERS
from openqdc.utils.constants import ATOM_SYMBOLS, ATOMIC_NUMBERS, MAX_CHARGE_NUMBER
from openqdc.utils.io import load_pkl, save_pkl
from openqdc.utils.regressor import Regressor

POSSIBLE_ENERGIES = ["formation", "regression", "null"]
MAX_CHARGE_NUMBER = 21


def dispatch_factory(data, **kwargs) -> "IsolatedEnergyInterface":
Expand Down
16 changes: 8 additions & 8 deletions openqdc/datasets/interaction/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
from .x40 import X40

AVAILABLE_INTERACTION_DATASETS = {
"des5m": DES5M,
"des370k": DES370K,
"dess66": DESS66,
"dess66x8": DESS66x8,
"l7": L7,
"metcalf": Metcalf,
"splinter": Splinter,
"x40": X40,
"DES5M": DES5M,
"DES370K": DES370K,
"DESS66": DESS66,
"DESS66x8": DESS66x8,
"L7": L7,
"Metcalf": Metcalf,
"Splinter": Splinter,
"X40": X40,
}
60 changes: 30 additions & 30 deletions openqdc/datasets/potential/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .ani import ANI1, ANI1CCX, ANI1CCX_V2, ANI1X
from .ani import ANI1, ANI1CCX, ANI1CCX_V2, ANI1X, ANI2
from .comp6 import COMP6
from .dummy import Dummy
from .gdml import GDML
Expand All @@ -21,33 +21,33 @@
from .waterclusters3_30 import WaterClusters

AVAILABLE_POTENTIAL_DATASETS = {
"ani1": ANI1,
"ani1ccx": ANI1CCX,
"ani1ccxv2": ANI1CCX_V2,
"ani1x": ANI1X,
"comp6": COMP6,
"gdml": GDML,
"geom": GEOM,
"iso17": ISO17,
"molecule3d": Molecule3D,
"nabladft": NablaDFT,
"orbnetdenali": OrbnetDenali,
"pcqmb3lyp": PCQM_B3LYP,
"pcqmpm6": PCQM_PM6,
"qm7x": QM7X,
"qm7xv2": QM7X_V2,
"qmugs": QMugs,
"qmugsv2": QMugs_V2,
"sn2rxn": SN2RXN,
"solvatedpeptides": SolvatedPeptides,
"spice": Spice,
"spicev2": SpiceV2,
"spicevl2": SpiceVL2,
"tmqm": TMQM,
"transition1x": Transition1X,
"watercluster": WaterClusters,
"multixcqm9": MultixcQM9,
"multixcqm9v2": MultixcQM9_V2,
"revmd17": RevMD17,
"md22": MD22,
"ANI1": ANI1,
"ANI1CCX": ANI1CCX,
"ANI1CCX_V2": ANI1CCX_V2,
"ANI1X": ANI1X,
"COMP6": COMP6,
"GDML": GDML,
"GEOM": GEOM,
"ISO17": ISO17,
"Molecule3D": Molecule3D,
"NablaDFT": NablaDFT,
"OrbnetDenali": OrbnetDenali,
"PCQM_B3LYP": PCQM_B3LYP,
"PCQM_PM6": PCQM_PM6,
"QM7X": QM7X,
"QM7X_V2": QM7X_V2,
"QMugs": QMugs,
"QMugs_V2": QMugs_V2,
"SN2RXN": SN2RXN,
"SolvatedPeptides": SolvatedPeptides,
"Spice": Spice,
"SpiceV2": SpiceV2,
"SpiceVL2": SpiceVL2,
"TMQM": TMQM,
"Transition1X": Transition1X,
"WaterClusters": WaterClusters,
"MultixcQM9": MultixcQM9,
"MultixcQM9_V2": MultixcQM9_V2,
"RevMD17": RevMD17,
"MD22": MD22,
}
80 changes: 79 additions & 1 deletion openqdc/datasets/potential/ani.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,42 @@
import os
from os.path import join as p_join

import numpy as np

from openqdc.datasets.base import BaseDataset
from openqdc.methods import PotentialMethod
from openqdc.utils import read_qc_archive_h5
from openqdc.utils import load_hdf5_file, read_qc_archive_h5
from openqdc.utils.io import get_local_cache


def read_ani2_h5(raw_path):
    """Read an ANI-2 HDF5 file and convert every top-level entry.

    The file at ``raw_path`` is opened with ``load_hdf5_file``; each value it
    yields from ``.items()`` is handed to ``extract_ani2_entries`` (the keys
    are ignored).

    Returns
    -------
    list
        One extracted dictionary per top-level entry, in iteration order.
    """
    archive = load_hdf5_file(raw_path)
    return [extract_ani2_entries(group) for _, group in archive.items()]


def extract_ani2_entries(properties):
    """Convert one ANI-2 property group into flat per-entry / per-atom arrays.

    Parameters
    ----------
    properties
        Mapping providing ``"coordinates"``, ``"species"``, ``"forces"`` and
        ``"energies"``. ``coordinates`` is read as
        ``(n_entries, n_atoms, ...)`` and flattened to rows of 3 values;
        the other arrays are flattened to match.

    Returns
    -------
    dict
        ``name``          : ``"ANI2"`` repeated ``n_entries`` times.
        ``subset``        : ``str(n_atoms)`` repeated ``n_entries`` times.
        ``energies``      : float64 array of shape ``(-1, 1)``.
        ``atomic_inputs`` : float32 array with columns
                            ``[species, 0, x, y, z]``, one row per atom.
        ``n_atoms``       : int32 array holding ``n_atoms`` for each entry.
        ``forces``        : float32 array of shape ``(-1, 3, 1)``.
    """
    coordinates = properties["coordinates"]
    n_entries, n_atoms = coordinates.shape[0], coordinates.shape[1]

    # One row per atom across all entries.
    positions = coordinates[:].reshape((-1, 3))
    atomic_numbers = properties["species"][:].flatten()
    # NOTE(review): the zero column is presumably the per-atom charge slot
    # expected downstream -- confirm against the base dataset layout.
    zero_column = np.zeros(positions.shape[0])
    species_block = np.stack((atomic_numbers, zero_column), axis=-1)

    # np.full keeps the string/int dtype even when n_entries == 0 and avoids
    # building a temporary Python list with one element per entry.
    return dict(
        name=np.full(n_entries, "ANI2"),
        subset=np.full(n_entries, str(n_atoms)),
        energies=properties["energies"][:].reshape((-1, 1)).astype(np.float64),
        atomic_inputs=np.concatenate((species_block, positions), axis=-1, dtype=np.float32),
        n_atoms=np.full(n_entries, n_atoms, dtype=np.int32),
        forces=properties["forces"][:].reshape(-1, 3, 1).astype(np.float32),
    )


class ANI1(BaseDataset):
"""
The ANI-1 dataset is a collection of 22 x 10^6 structural conformations from 57,000 distinct small
Expand Down Expand Up @@ -176,3 +206,51 @@ class ANI1CCX_V2(ANI1CCX):

__energy_methods__ = ANI1CCX.__energy_methods__ + [PotentialMethod.PM6, PotentialMethod.GFN2_XTB]
energy_target_names = ANI1CCX.energy_target_names + ["PM6", "GFN2"]


class ANI2(ANI1):
    """
    ANI-2x dataset, exposed here at the wB97X/6-31G(d) level of theory.

    Energies are declared in hartree, distances in angstrom and forces in
    hartree/angstrom. The raw archive is downloaded from the Zenodo record
    listed in ``__links__``; the other levels of theory published in that
    record are kept below as commented-out entries and are not loaded.
    """

    __name__ = "ani2"
    __energy_unit__ = "hartree"
    __distance_unit__ = "ang"
    __forces_unit__ = "hartree/ang"

    # Only the level of theory used in the paper dataset is enabled.
    __energy_methods__ = [
        # PotentialMethod.NONE,  # "b973c/def2mtzvp",
        PotentialMethod.WB97X_6_31G_D,  # "wb97x/631gd",  # PAPER DATASET
        # PotentialMethod.NONE,  # "wb97md3bj/def2tzvpp",
        # PotentialMethod.NONE,  # "wb97mv/def2tzvpp",
        # PotentialMethod.NONE,  # "wb97x/def2tzvpp",
    ]

    energy_target_names = [
        # "b973c/def2mtzvp",
        "wb97x/631gd",
        # "wb97md3bj/def2tzvpp",
        # "wb97mv/def2tzvpp",
        # "wb97x/def2tzvpp",
    ]

    force_target_names = ["wb97x/631gd"]  # "b973c/def2mtzvp",

    __force_mask__ = [True]
    __links__ = {
        # "ANI-2x-B973c-def2mTZVP.tar.gz": "https://zenodo.org/records/10108942/files/ANI-2x-B973c-def2mTZVP.tar.gz?download=1",  # noqa
        # "ANI-2x-wB97MD3BJ-def2TZVPP.tar.gz": "https://zenodo.org/records/10108942/files/ANI-2x-wB97MD3BJ-def2TZVPP.tar.gz?download=1",  # noqa
        # "ANI-2x-wB97MV-def2TZVPP.tar.gz": "https://zenodo.org/records/10108942/files/ANI-2x-wB97MV-def2TZVPP.tar.gz?download=1",  # noqa
        "ANI-2x-wB97X-631Gd.tar.gz": "https://zenodo.org/records/10108942/files/ANI-2x-wB97X-631Gd.tar.gz?download=1",  # noqa
        # "ANI-2x-wB97X-def2TZVPP.tar.gz": "https://zenodo.org/records/10108942/files/ANI-2x-wB97X-def2TZVPP.tar.gz?download=1",  # noqa
    }

    def __smiles_converter__(self, x):
        """Return ``x`` unchanged.

        Hook for converting a stored name into its display SMILES; this
        dataset needs no conversion, so the value is passed straight through.
        """
        return x

    def read_raw_entries(self):
        """Read every archive listed in ``__links__`` from ``<root>/final_h5``.

        For each link key, the text before the first ``.`` is used as the
        HDF5 file stem (e.g. ``ANI-2x-wB97X-631Gd.tar.gz`` ->
        ``ANI-2x-wB97X-631Gd.h5``). The entries returned by ``read_ani2_h5``
        for each file are concatenated into a single list.
        """
        collected = []
        for archive_name in self.__links__:
            stem = archive_name.split(".")[0]
            h5_path = p_join(self.root, "final_h5", f"{stem}.h5")
            collected.extend(read_ani2_h5(h5_path))
        return collected
4 changes: 2 additions & 2 deletions openqdc/datasets/potential/comp6.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ class COMP6(BaseDataset):

# watchout that forces are stored as -grad(E)
__energy_unit__ = "kcal/mol"
__distance_unit__ = "bohr" # bohr
__forces_unit__ = "kcal/mol/bohr"
__distance_unit__ = "ang" # angstrom
__forces_unit__ = "kcal/mol/ang"

__energy_methods__ = [
PotentialMethod.WB97X_6_31G_D, # "wb97x/6-31g*",
Expand Down
4 changes: 2 additions & 2 deletions openqdc/datasets/potential/gdml.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ class GDML(BaseDataset):
]

__energy_unit__ = "kcal/mol"
__distance_unit__ = "bohr"
__forces_unit__ = "kcal/mol/bohr"
__distance_unit__ = "ang"
__forces_unit__ = "kcal/mol/ang"
__links__ = {
"gdb7_9.hdf5.gz": "https://zenodo.org/record/3588361/files/208.hdf5.gz",
"gdb10_13.hdf5.gz": "https://zenodo.org/record/3588364/files/209.hdf5.gz",
Expand Down
4 changes: 2 additions & 2 deletions openqdc/datasets/potential/iso_17.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ class ISO17(BaseDataset):
]

__energy_unit__ = "ev"
__distance_unit__ = "bohr" # bohr
__forces_unit__ = "ev/bohr"
__distance_unit__ = "ang"
__forces_unit__ = "ev/ang"
__links__ = {"iso_17.hdf5.gz": "https://zenodo.org/record/3585907/files/216.hdf5.gz"}

def __smiles_converter__(self, x):
Expand Down
4 changes: 2 additions & 2 deletions openqdc/datasets/potential/molecule3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def read_mol(mol: Chem.rdchem.Mol, energy: float) -> Dict[str, np.ndarray]:
res = dict(
name=np.array([smiles]),
subset=np.array(["molecule3d"]),
energies=np.array([energy]).astype(np.float32)[:, None],
atomic_inputs=np.concatenate((x, positions), axis=-1, dtype=np.float64),
energies=np.array([energy]).astype(np.float64)[:, None],
atomic_inputs=np.concatenate((x, positions), axis=-1, dtype=np.float32),
n_atoms=np.array([x.shape[0]], dtype=np.int32),
)

Expand Down
2 changes: 1 addition & 1 deletion openqdc/datasets/potential/qm7x.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class QM7X(BaseDataset):

__energy_methods__ = [PotentialMethod.PBE0_DEF2_TZVP, PotentialMethod.DFT3B] # "pbe0/def2-tzvp", "dft3b"]

energy_target_names = ["ePBE0", "eMBD"]
energy_target_names = ["ePBE0+MBD", "eDFTB+MBD"]

__force_mask__ = [True, True]

Expand Down
Loading
Loading