Skip to content

Commit 6ffa749

Browse files
authored
Merge pull request #86 from OpenDrugDiscovery/downloader_add
unique enums + initial structure for api endpoint
2 parents 2669d21 + c2d13ad commit 6ffa749

21 files changed

+334
-132
lines changed

openqdc/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def get_project_root():
1919
"ANI1CCX": "openqdc.datasets.potential.ani",
2020
"ANI1CCX_V2": "openqdc.datasets.potential.ani",
2121
"ANI1X": "openqdc.datasets.potential.ani",
22+
"ANI2": "openqdc.datasets.potential.ani",
2223
"Spice": "openqdc.datasets.potential.spice",
2324
"SpiceV2": "openqdc.datasets.potential.spice",
2425
"SpiceVL2": "openqdc.datasets.potential.spice",
@@ -100,7 +101,7 @@ def __dir__():
100101
from .datasets.interaction.metcalf import Metcalf
101102
from .datasets.interaction.splinter import Splinter
102103
from .datasets.interaction.x40 import X40
103-
from .datasets.potential.ani import ANI1, ANI1CCX, ANI1CCX_V2, ANI1X
104+
from .datasets.potential.ani import ANI1, ANI1CCX, ANI1CCX_V2, ANI1X, ANI2
104105
from .datasets.potential.comp6 import COMP6
105106
from .datasets.potential.dummy import Dummy
106107
from .datasets.potential.gdml import GDML

openqdc/cli.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,15 @@
1616
app = typer.Typer(help="OpenQDC CLI")
1717

1818

19+
def sanitize(dictionary):
20+
return {k.lower().replace("_", "").replace("-", ""): v for k, v in dictionary.items()}
21+
22+
23+
SANITIZED_AVAILABLE_DATASETS = sanitize(AVAILABLE_DATASETS)
24+
25+
1926
def exist_dataset(dataset):
20-
if dataset not in AVAILABLE_DATASETS:
27+
if dataset not in sanitize(AVAILABLE_DATASETS):
2128
logger.error(f"{dataset} is not available. Please open an issue on Github for the team to look into it.")
2229
return False
2330
return True
@@ -57,10 +64,10 @@ def download(
5764
"""
5865
for dataset in list(map(lambda x: x.lower().replace("_", ""), datasets)):
5966
if exist_dataset(dataset):
60-
if AVAILABLE_DATASETS[dataset].no_init().is_cached() and not overwrite:
67+
if SANITIZED_AVAILABLE_DATASETS[dataset].no_init().is_cached() and not overwrite:
6168
logger.info(f"{dataset} is already cached. Skipping download")
6269
else:
63-
AVAILABLE_DATASETS[dataset](overwrite_local_cache=True, cache_dir=cache_dir)
70+
SANITIZED_AVAILABLE_DATASETS[dataset](overwrite_local_cache=True, cache_dir=cache_dir)
6471

6572

6673
@app.command()
@@ -115,18 +122,17 @@ def fetch(
115122
openqdc fetch Spice
116123
"""
117124
if datasets[0].lower() == "all":
118-
dataset_names = AVAILABLE_DATASETS
125+
dataset_names = list(sanitize(AVAILABLE_DATASETS).keys())
119126
elif datasets[0].lower() == "potential":
120-
dataset_names = AVAILABLE_POTENTIAL_DATASETS
127+
dataset_names = list(sanitize(AVAILABLE_POTENTIAL_DATASETS).keys())
121128
elif datasets[0].lower() == "interaction":
122-
dataset_names = AVAILABLE_INTERACTION_DATASETS
129+
dataset_names = list(sanitize(AVAILABLE_INTERACTION_DATASETS).keys())
123130
else:
124131
dataset_names = datasets
125-
126132
for dataset in list(map(lambda x: x.lower().replace("_", ""), dataset_names)):
127133
if exist_dataset(dataset):
128134
try:
129-
AVAILABLE_DATASETS[dataset].fetch(cache_dir, overwrite)
135+
SANITIZED_AVAILABLE_DATASETS[dataset].fetch(cache_dir, overwrite)
130136
except Exception as e:
131137
logger.error(f"Something unexpected happended while fetching {dataset}: {repr(e)}")
132138

@@ -152,9 +158,9 @@ def preprocess(
152158
"""
153159
for dataset in list(map(lambda x: x.lower().replace("_", ""), datasets)):
154160
if exist_dataset(dataset):
155-
logger.info(f"Preprocessing {AVAILABLE_DATASETS[dataset].__name__}")
161+
logger.info(f"Preprocessing {SANITIZED_AVAILABLE_DATASETS[dataset].__name__}")
156162
try:
157-
AVAILABLE_DATASETS[dataset].no_init().preprocess(upload=upload, overwrite=overwrite)
163+
SANITIZED_AVAILABLE_DATASETS[dataset].no_init().preprocess(upload=upload, overwrite=overwrite)
158164
except Exception as e:
159165
logger.error(f"Error while preprocessing {dataset}. {e}. Did you fetch the dataset first?")
160166
raise e

openqdc/datasets/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def __init__(
8989
energy_unit: Optional[str] = None,
9090
distance_unit: Optional[str] = None,
9191
array_format: str = "numpy",
92-
energy_type: str = "formation",
92+
energy_type: Optional[str] = "formation",
9393
overwrite_local_cache: bool = False,
9494
cache_dir: Optional[str] = None,
9595
recompute_statistics: bool = False,
@@ -112,7 +112,7 @@ def __init__(
112112
Format to return arrays in. Supported formats: ["numpy", "torch", "jax"]
113113
energy_type
114114
Type of isolated atom energy to use for the dataset. Default: "formation"
115-
Supported types: ["formation", "regression", "null"]
115+
Supported types: ["formation", "regression", "null", None]
116116
overwrite_local_cache
117117
Whether to overwrite the locally cached dataset.
118118
cache_dir
@@ -133,7 +133,7 @@ def __init__(
133133
self.recompute_statistics = recompute_statistics
134134
self.regressor_kwargs = regressor_kwargs
135135
self.transform = transform
136-
self.energy_type = energy_type
136+
self.energy_type = energy_type if energy_type is not None else "null"
137137
self.refit_e0s = recompute_statistics or overwrite_local_cache
138138
if not self.is_preprocessed():
139139
raise DatasetNotAvailableError(self.__name__)

openqdc/datasets/energies.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,11 @@
77
from loguru import logger
88

99
from openqdc.methods.enums import PotentialMethod
10-
from openqdc.utils.constants import ATOM_SYMBOLS, ATOMIC_NUMBERS
10+
from openqdc.utils.constants import ATOM_SYMBOLS, ATOMIC_NUMBERS, MAX_CHARGE_NUMBER
1111
from openqdc.utils.io import load_pkl, save_pkl
1212
from openqdc.utils.regressor import Regressor
1313

1414
POSSIBLE_ENERGIES = ["formation", "regression", "null"]
15-
MAX_CHARGE_NUMBER = 21
1615

1716

1817
def dispatch_factory(data, **kwargs) -> "IsolatedEnergyInterface":

openqdc/datasets/interaction/__init__.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
from .x40 import X40
77

88
AVAILABLE_INTERACTION_DATASETS = {
9-
"des5m": DES5M,
10-
"des370k": DES370K,
11-
"dess66": DESS66,
12-
"dess66x8": DESS66x8,
13-
"l7": L7,
14-
"metcalf": Metcalf,
15-
"splinter": Splinter,
16-
"x40": X40,
9+
"DES5M": DES5M,
10+
"DES370K": DES370K,
11+
"DESS66": DESS66,
12+
"DESS66x8": DESS66x8,
13+
"L7": L7,
14+
"Metcalf": Metcalf,
15+
"Splinter": Splinter,
16+
"X40": X40,
1717
}
Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .ani import ANI1, ANI1CCX, ANI1CCX_V2, ANI1X
1+
from .ani import ANI1, ANI1CCX, ANI1CCX_V2, ANI1X, ANI2
22
from .comp6 import COMP6
33
from .dummy import Dummy
44
from .gdml import GDML
@@ -21,33 +21,33 @@
2121
from .waterclusters3_30 import WaterClusters
2222

2323
AVAILABLE_POTENTIAL_DATASETS = {
24-
"ani1": ANI1,
25-
"ani1ccx": ANI1CCX,
26-
"ani1ccxv2": ANI1CCX_V2,
27-
"ani1x": ANI1X,
28-
"comp6": COMP6,
29-
"gdml": GDML,
30-
"geom": GEOM,
31-
"iso17": ISO17,
32-
"molecule3d": Molecule3D,
33-
"nabladft": NablaDFT,
34-
"orbnetdenali": OrbnetDenali,
35-
"pcqmb3lyp": PCQM_B3LYP,
36-
"pcqmpm6": PCQM_PM6,
37-
"qm7x": QM7X,
38-
"qm7xv2": QM7X_V2,
39-
"qmugs": QMugs,
40-
"qmugsv2": QMugs_V2,
41-
"sn2rxn": SN2RXN,
42-
"solvatedpeptides": SolvatedPeptides,
43-
"spice": Spice,
44-
"spicev2": SpiceV2,
45-
"spicevl2": SpiceVL2,
46-
"tmqm": TMQM,
47-
"transition1x": Transition1X,
48-
"watercluster": WaterClusters,
49-
"multixcqm9": MultixcQM9,
50-
"multixcqm9v2": MultixcQM9_V2,
51-
"revmd17": RevMD17,
52-
"md22": MD22,
24+
"ANI1": ANI1,
25+
"ANI1CCX": ANI1CCX,
26+
"ANI1CCX_V2": ANI1CCX_V2,
27+
"ANI1X": ANI1X,
28+
"COMP6": COMP6,
29+
"GDML": GDML,
30+
"GEOM": GEOM,
31+
"ISO17": ISO17,
32+
"Molecule3D": Molecule3D,
33+
"NablaDFT": NablaDFT,
34+
"OrbnetDenali": OrbnetDenali,
35+
"PCQM_B3LYP": PCQM_B3LYP,
36+
"PCQM_PM6": PCQM_PM6,
37+
"QM7X": QM7X,
38+
"QM7X_V2": QM7X_V2,
39+
"QMugs": QMugs,
40+
"QMugs_V2": QMugs_V2,
41+
"SN2RXN": SN2RXN,
42+
"SolvatedPeptides": SolvatedPeptides,
43+
"Spice": Spice,
44+
"SpiceV2": SpiceV2,
45+
"SpiceVL2": SpiceVL2,
46+
"TMQM": TMQM,
47+
"Transition1X": Transition1X,
48+
"WaterClusters": WaterClusters,
49+
"MultixcQM9": MultixcQM9,
50+
"MultixcQM9_V2": MultixcQM9_V2,
51+
"RevMD17": RevMD17,
52+
"MD22": MD22,
5353
}

openqdc/datasets/potential/ani.py

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,42 @@
11
import os
22
from os.path import join as p_join
33

4+
import numpy as np
5+
46
from openqdc.datasets.base import BaseDataset
57
from openqdc.methods import PotentialMethod
6-
from openqdc.utils import read_qc_archive_h5
8+
from openqdc.utils import load_hdf5_file, read_qc_archive_h5
79
from openqdc.utils.io import get_local_cache
810

911

12+
def read_ani2_h5(raw_path):
13+
h5f = load_hdf5_file(raw_path)
14+
samples = []
15+
for _, props in h5f.items():
16+
samples.append(extract_ani2_entries(props))
17+
return samples
18+
19+
20+
def extract_ani2_entries(properties):
21+
coordinates = properties["coordinates"]
22+
species = properties["species"]
23+
forces = properties["forces"]
24+
energies = properties["energies"]
25+
n_atoms = coordinates.shape[1]
26+
n_entries = coordinates.shape[0]
27+
flattened_coordinates = coordinates[:].reshape((-1, 3))
28+
xs = np.stack((species[:].flatten(), np.zeros(flattened_coordinates.shape[0])), axis=-1)
29+
res = dict(
30+
name=np.array(["ANI2"] * n_entries),
31+
subset=np.array([str(n_atoms)] * n_entries),
32+
energies=energies[:].reshape((-1, 1)).astype(np.float64),
33+
atomic_inputs=np.concatenate((xs, flattened_coordinates), axis=-1, dtype=np.float32),
34+
n_atoms=np.array([n_atoms] * n_entries, dtype=np.int32),
35+
forces=forces[:].reshape(-1, 3, 1).astype(np.float32),
36+
)
37+
return res
38+
39+
1040
class ANI1(BaseDataset):
1141
"""
1242
The ANI-1 dataset is a collection of 22 x 10^6 structural conformations from 57,000 distinct small
@@ -176,3 +206,51 @@ class ANI1CCX_V2(ANI1CCX):
176206

177207
__energy_methods__ = ANI1CCX.__energy_methods__ + [PotentialMethod.PM6, PotentialMethod.GFN2_XTB]
178208
energy_target_names = ANI1CCX.energy_target_names + ["PM6", "GFN2"]
209+
210+
211+
class ANI2(ANI1):
212+
""" """
213+
214+
__name__ = "ani2"
215+
__energy_unit__ = "hartree"
216+
__distance_unit__ = "ang"
217+
__forces_unit__ = "hartree/ang"
218+
219+
__energy_methods__ = [
220+
# PotentialMethod.NONE, # "b973c/def2mtzvp",
221+
PotentialMethod.WB97X_6_31G_D, # "wb97x/631gd", # PAPER DATASET
222+
# PotentialMethod.NONE, # "wb97md3bj/def2tzvpp",
223+
# PotentialMethod.NONE, # "wb97mv/def2tzvpp",
224+
# PotentialMethod.NONE, # "wb97x/def2tzvpp",
225+
]
226+
227+
energy_target_names = [
228+
# "b973c/def2mtzvp",
229+
"wb97x/631gd",
230+
# "wb97md3bj/def2tzvpp",
231+
# "wb97mv/def2tzvpp",
232+
# "wb97x/def2tzvpp",
233+
]
234+
235+
force_target_names = ["wb97x/631gd"] # "b973c/def2mtzvp",
236+
237+
__force_mask__ = [True]
238+
__links__ = { # "ANI-2x-B973c-def2mTZVP.tar.gz": "https://zenodo.org/records/10108942/files/ANI-2x-B973c-def2mTZVP.tar.gz?download=1", # noqa
239+
# "ANI-2x-wB97MD3BJ-def2TZVPP.tar.gz": "https://zenodo.org/records/10108942/files/ANI-2x-wB97MD3BJ-def2TZVPP.tar.gz?download=1", # noqa
240+
# "ANI-2x-wB97MV-def2TZVPP.tar.gz": "https://zenodo.org/records/10108942/files/ANI-2x-wB97MV-def2TZVPP.tar.gz?download=1", # noqa
241+
"ANI-2x-wB97X-631Gd.tar.gz": "https://zenodo.org/records/10108942/files/ANI-2x-wB97X-631Gd.tar.gz?download=1", # noqa
242+
# "ANI-2x-wB97X-def2TZVPP.tar.gz": "https://zenodo.org/records/10108942/files/ANI-2x-wB97X-def2TZVPP.tar.gz?download=1", # noqa
243+
}
244+
245+
def __smiles_converter__(self, x):
246+
"""util function to convert string to smiles: useful if the smiles is
247+
encoded in a different format than its display format
248+
"""
249+
return x
250+
251+
def read_raw_entries(self):
252+
samples = []
253+
for lvl_theory in self.__links__.keys():
254+
raw_path = p_join(self.root, "final_h5", f"{lvl_theory.split('.')[0]}.h5")
255+
samples.extend(read_ani2_h5(raw_path))
256+
return samples

openqdc/datasets/potential/comp6.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ class COMP6(BaseDataset):
2626

2727
# watch out that forces are stored as -grad(E)
2828
__energy_unit__ = "kcal/mol"
29-
__distance_unit__ = "bohr" # bohr
30-
__forces_unit__ = "kcal/mol/bohr"
29+
__distance_unit__ = "ang" # angstrom
30+
__forces_unit__ = "kcal/mol/ang"
3131

3232
__energy_methods__ = [
3333
PotentialMethod.WB97X_6_31G_D, # "wb97x/6-31g*",

openqdc/datasets/potential/gdml.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ class GDML(BaseDataset):
5454
]
5555

5656
__energy_unit__ = "kcal/mol"
57-
__distance_unit__ = "bohr"
58-
__forces_unit__ = "kcal/mol/bohr"
57+
__distance_unit__ = "ang"
58+
__forces_unit__ = "kcal/mol/ang"
5959
__links__ = {
6060
"gdb7_9.hdf5.gz": "https://zenodo.org/record/3588361/files/208.hdf5.gz",
6161
"gdb10_13.hdf5.gz": "https://zenodo.org/record/3588364/files/209.hdf5.gz",

openqdc/datasets/potential/iso_17.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ class ISO17(BaseDataset):
4040
]
4141

4242
__energy_unit__ = "ev"
43-
__distance_unit__ = "bohr" # bohr
44-
__forces_unit__ = "ev/bohr"
43+
__distance_unit__ = "ang"
44+
__forces_unit__ = "ev/ang"
4545
__links__ = {"iso_17.hdf5.gz": "https://zenodo.org/record/3585907/files/216.hdf5.gz"}
4646

4747
def __smiles_converter__(self, x):

openqdc/datasets/potential/molecule3d.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ def read_mol(mol: Chem.rdchem.Mol, energy: float) -> Dict[str, np.ndarray]:
4141
res = dict(
4242
name=np.array([smiles]),
4343
subset=np.array(["molecule3d"]),
44-
energies=np.array([energy]).astype(np.float32)[:, None],
45-
atomic_inputs=np.concatenate((x, positions), axis=-1, dtype=np.float64),
44+
energies=np.array([energy]).astype(np.float64)[:, None],
45+
atomic_inputs=np.concatenate((x, positions), axis=-1, dtype=np.float32),
4646
n_atoms=np.array([x.shape[0]], dtype=np.int32),
4747
)
4848

openqdc/datasets/potential/qm7x.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ class QM7X(BaseDataset):
5757

5858
__energy_methods__ = [PotentialMethod.PBE0_DEF2_TZVP, PotentialMethod.DFT3B] # "pbe0/def2-tzvp", "dft3b"]
5959

60-
energy_target_names = ["ePBE0", "eMBD"]
60+
energy_target_names = ["ePBE0+MBD", "eDFTB+MBD"]
6161

6262
__force_mask__ = [True, True]
6363

0 commit comments

Comments
 (0)