
Commit 74f1c19

Merged format_change + Fix NOT_DEFINED
2 parents: 51a5191 + 16dcb4e

File tree: 15 files changed, +421 additions, -189 deletions


src/openqdc/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@
     "Transition1X": "openqdc.datasets.transition1x",
 }

-_lazy_imports_mod = {"datasets": "openqdc.datamodule", "utils": "openqdc.utils"}
+_lazy_imports_mod = {"datasets": "openqdc.datasets", "utils": "openqdc.utils"}


 def __getattr__(name):
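For context, a minimal sketch of the module-level lazy-import pattern this mapping drives; the actual __getattr__ body in openqdc/__init__.py may differ:

    import importlib

    # Map attribute names to the modules that back them (the fixed mapping above)
    _lazy_imports_mod = {"datasets": "openqdc.datasets", "utils": "openqdc.utils"}

    def __getattr__(name):
        # Resolve the submodule on first access instead of at package import time
        if name in _lazy_imports_mod:
            return importlib.import_module(_lazy_imports_mod[name])
        raise AttributeError(f"module 'openqdc' has no attribute {name!r}")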

src/openqdc/datasets/ani.py

Lines changed: 18 additions & 0 deletions
@@ -43,6 +43,12 @@ class ANI1(BaseDataset):
     def root(self):
         return p_join(get_local_cache(), "ani")

+    def __smiles_converter__(self, x):
+        """util function to convert string to smiles: useful if the smiles is
+        encoded in a different format than its display format
+        """
+        return "-".join(x.decode("ascii").split("-")[:-1])
+
     @property
     def preprocess_path(self):
         path = p_join(self.root, "preprocessed", self.__name__)
@@ -132,6 +138,12 @@ def _stats(self):
         },
     }

+    def __smiles_converter__(self, x):
+        """util function to convert string to smiles: useful if the smiles is
+        encoded in a different format than its display format
+        """
+        return x
+

 class ANI1X(ANI1):
     """
@@ -317,3 +329,9 @@ def _stats(self):

     def convert_forces(self, x):
         return super().convert_forces(x) * 0.529177249  # correct the Dataset error
+
+    def __smiles_converter__(self, x):
+        """util function to convert string to smiles: useful if the smiles is
+        encoded in a different format than its display format
+        """
+        return x
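To make the ANI1 converter concrete, a small sketch with a hypothetical record name (the real byte strings come from the dataset's HDF5 files):

    # ANI1 stores names as byte strings with a trailing conformer index;
    # the converter strips that index to recover the display name
    x = b"C8H11NO2-3349-28"  # hypothetical entry
    name = "-".join(x.decode("ascii").split("-")[:-1])
    print(name)  # -> "C8H11NO2-3349"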

src/openqdc/datasets/base.py

Lines changed: 42 additions & 42 deletions
@@ -1,4 +1,5 @@
 import os
+import pickle as pkl
 from os.path import join as p_join
 from typing import Dict, List, Optional, Union

@@ -26,12 +27,11 @@
     dict_to_atoms,
     get_local_cache,
     load_hdf5_file,
-    load_pkl,
     pull_locally,
     push_remote,
     set_cache_dir,
 )
-from openqdc.utils.molecule import atom_table
+from openqdc.utils.molecule import atom_table, z_to_formula
 from openqdc.utils.package_utils import requires_package
 from openqdc.utils.units import get_conversion

@@ -50,7 +50,7 @@ def extract_entry(

     res = dict(
         name=np.array([df["name"][i]]),
-        subset=np.array([subset]),
+        subset=np.array([subset if subset is not None else z_to_formula(x)]),
         energies=energies.reshape((1, -1)).astype(np.float32),
         atomic_inputs=np.concatenate((xs, positions), axis=-1, dtype=np.float32),
         n_atoms=np.array([x.shape[0]], dtype=np.int32),
@@ -71,8 +71,8 @@ def read_qc_archive_h5(
 ) -> List[Dict[str, np.ndarray]]:
     data = load_hdf5_file(raw_path)
     data_t = {k2: data[k1][k2][:] for k1 in data.keys() for k2 in data[k1].keys()}
-    n = len(data_t["molecule_id"])

+    n = len(data_t["molecule_id"])
     samples = [extract_entry(data_t, i, subset, energy_target_names, force_target_names) for i in tqdm(range(n))]
     return samples

@@ -103,7 +103,7 @@ def __init__(
         self.data = None
         self._set_units(energy_unit, distance_unit)
         if not self.is_preprocessed():
-            self._download()
+            raise DatasetNotAvailableError(self.__name__)
         else:
             self.read_preprocess(overwrite_local_cache=overwrite_local_cache)
         self._set_isolated_atom_energies()
@@ -120,12 +120,12 @@ def _download(self):
     def numbers(self):
         if hasattr(self, "_numbers"):
             return self._numbers
-        self._numbers = np.array(list(set(self.data["atomic_inputs"][..., 0])), dtype=np.int32)
+        self._numbers = np.unique(self.data["atomic_inputs"][..., 0]).astype(np.int32)
         return self._numbers

     @property
     def chemical_species(self):
-        return [chemical_symbols[z] for z in self.numbers]
+        return np.array(chemical_symbols)[self.numbers]

     @property
     def energy_unit(self):
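Both one-liners above are easy to sanity-check in isolation; a toy sketch, assuming chemical_symbols is the ase.data lookup table (as the ASE helpers elsewhere in this file suggest):

    import numpy as np
    from ase.data import chemical_symbols

    # Toy atomic-number column, shaped like data["atomic_inputs"][..., 0]
    zs = np.array([1.0, 6.0, 1.0, 8.0, 6.0])
    numbers = np.unique(zs).astype(np.int32)       # sorted and deduplicated: [1, 6, 8]
    species = np.array(chemical_symbols)[numbers]  # array(['H', 'C', 'O'], dtype='<U2')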
@@ -224,10 +224,11 @@ def collate_list(self, list_entries):
         # concatenate entries
         res = {key: np.concatenate([r[key] for r in list_entries if r is not None], axis=0) for key in list_entries[0]}

-        csum = np.cumsum(res.pop("n_atoms"))
+        csum = np.cumsum(res.get("n_atoms"))
         x = np.zeros((csum.shape[0], 2), dtype=np.int32)
         x[1:, 0], x[:, 1] = csum[:-1], csum
         res["position_idx_range"] = x
+
         return res

     def save_preprocess(self, data_dict):
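The cumulative-sum trick builds one [start, end) row per molecule; a minimal sketch:

    import numpy as np

    # Toy n_atoms column for three molecules of 2, 3 and 1 atoms
    csum = np.cumsum(np.array([2, 3, 1]))  # [2, 5, 6]
    x = np.zeros((csum.shape[0], 2), dtype=np.int32)
    x[1:, 0], x[:, 1] = csum[:-1], csum
    # x == [[0, 2], [2, 5], [5, 6]]: molecule i owns atom rows x[i, 0]:x[i, 1]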
@@ -241,12 +242,13 @@ def save_preprocess(self, data_dict):
             push_remote(local_path, overwrite=True)

         # save smiles and subset
+        local_path = p_join(self.preprocess_path, "props.pkl")
         for key in ["name", "subset"]:
-            local_path = p_join(self.preprocess_path, f"{key}.npz")
-            uniques, inv_indices = np.unique(data_dict[key], return_inverse=True)
-            with open(local_path, "wb") as f:
-                np.savez_compressed(f, uniques=uniques, inv_indices=inv_indices)
-            push_remote(local_path)
+            data_dict[key] = np.unique(data_dict[key], return_inverse=True)
+
+        with open(local_path, "wb") as f:
+            pkl.dump(data_dict, f)
+        push_remote(local_path, overwrite=True)

     def read_preprocess(self, overwrite_local_cache=False):
         logger.info("Reading preprocessed data")
@@ -260,36 +262,29 @@ def read_preprocess(self, overwrite_local_cache=False):
         for key in self.data_keys:
             filename = p_join(self.preprocess_path, f"{key}.mmap")
             pull_locally(filename, overwrite=overwrite_local_cache)
-            self.data[key] = np.memmap(
-                filename,
-                mode="r",
-                dtype=self.data_types[key],
-            ).reshape(self.data_shapes[key])
+            self.data[key] = np.memmap(filename, mode="r", dtype=self.data_types[key]).reshape(self.data_shapes[key])
+
+        filename = p_join(self.preprocess_path, "props.pkl")
+        pull_locally(filename, overwrite=overwrite_local_cache)
+        with open(filename, "rb") as f:
+            tmp = pkl.load(f)
+            for key in ["name", "subset", "n_atoms"]:
+                x = tmp.pop(key)
+                if len(x) == 2:
+                    self.data[key] = x[0][x[1]]
+                else:
+                    self.data[key] = x

         for key in self.data:
             logger.info(f"Loaded {key} with shape {self.data[key].shape}, dtype {self.data[key].dtype}")

-        for key in ["props"]:
-            filename = p_join(self.preprocess_path, f"{key}.pkl")
-            pull_locally(filename)
-            for key, v in load_pkl(filename).items():
-                self.data[key] = dict()
-                if key == "n_atoms":
-                    self.data[key] = v
-                    logger.info(f"Loaded {key} with shape {self.data[key].shape}, dtype {self.data[key].dtype}")
-                else:
-                    self.data[key]["uniques"] = v[0]
-                    self.data[key]["inv_indices"] = v[1]
-                    logger.info(f"Loaded {key}_uniques with shape {v[0].shape}, dtype {v[0].dtype}")
-                    logger.info(f"Loaded {key}_inv_indices with shape {v[1].shape}, dtype {v[1].dtype}")
-
     def is_preprocessed(self):
         predicats = [copy_exists(p_join(self.preprocess_path, f"{key}.mmap")) for key in self.data_keys]
-        predicats += [copy_exists(p_join(self.preprocess_path, f"{x}.pkl")) for x in ["props"]]
+        predicats += [copy_exists(p_join(self.preprocess_path, "props.pkl"))]
         return all(predicats)

-    def preprocess(self):
-        if not self.is_preprocessed():
+    def preprocess(self, overwrite=False):
+        if overwrite or not self.is_preprocessed():
             entries = self.read_raw_entries()
             res = self.collate_list(entries)
             self.save_preprocess(res)
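The round trip for the pickled properties is worth spelling out; a minimal sketch of the compression save_preprocess now applies and read_preprocess undoes:

    import numpy as np

    # Toy "name" column with heavy repetition, as in large conformer datasets
    names = np.array(["CCO", "CCO", "c1ccccc1", "CCO"])
    uniques, inv = np.unique(names, return_inverse=True)

    # save_preprocess pickles this 2-tuple; read_preprocess sees len(x) == 2
    # and rebuilds the full column by fancy indexing:
    restored = uniques[inv]
    assert (restored == names).all()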
@@ -323,7 +318,7 @@ def get_ase_atoms(self, idx: int, ext=True):

     @requires_package("dscribe")
     @requires_package("datamol")
-    def chemical_space(
+    def soap_descriptors(
         self,
         n_samples: Optional[Union[List[int], int]] = None,
         return_idxs: bool = True,
@@ -368,7 +363,7 @@ def chemical_space(
             idxs = list(range(len(self)))
         elif isinstance(n_samples, int):
             idxs = np.random.choice(len(self), size=n_samples, replace=False)
-        elif isinstance(n_samples, list):
+        else:  # list, set, np.ndarray
             idxs = n_samples
         datum = {}
         r_cut = soap_kwargs.pop("r_cut", 5.0)
@@ -401,7 +396,7 @@ def wrapper(idx):
             entry = self.get_ase_atoms(idx, ext=False)
             return soap.create(entry, centers=entry.positions)

-        descr = dm.parallelized(wrapper, idxs, progress=progress, scheduler="threads")
+        descr = dm.parallelized(wrapper, idxs, progress=progress, scheduler="threads", n_jobs=-1)
         datum["soap"] = np.vstack(descr)
         if return_idxs:
             datum["idxs"] = idxs
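A hypothetical end-to-end call of the renamed method; the dataset class, the availability of a preprocessed cache, and the optional dscribe/datamol extras are all assumptions here:

    from openqdc.datasets.ani import ANI1  # hypothetical import for this layout

    ds = ANI1()                                 # requires the preprocessed cache
    datum = ds.soap_descriptors(n_samples=128)  # random subset of 128 structures
    soap, idxs = datum["soap"], datum["idxs"]   # stacked per-center SOAP vectors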
@@ -410,6 +405,12 @@ def wrapper(idx):
     def __len__(self):
         return self.data["energies"].shape[0]

+    def __smiles_converter__(self, x):
+        """util function to convert string to smiles: useful if the smiles is
+        encoded in a different format than its display format
+        """
+        return x
+
     def __getitem__(self, idx: int):
         shift = IsolatedAtomEnergyFactory.max_charge
         p_start, p_end = self.data["position_idx_range"][idx]
@@ -420,9 +421,9 @@
             self.convert_distance(np.array(input[:, -3:], dtype=np.float32)),
             self.convert_energy(np.array(self.data["energies"][idx], dtype=np.float32)),
         )
-        name = self.data["name"]["uniques"][self.data["name"]["inv_indices"][idx]]
-        subset = self.data["subset"]["uniques"][self.data["subset"]["inv_indices"][idx]]
-        n_atoms = self.data["n_atoms"][idx]
+        name = self.__smiles_converter__(self.data["name"][idx])
+        subset = self.data["subset"][idx]
+
         if "forces" in self.data:
             forces = self.convert_forces(np.array(self.data["forces"][p_start:p_end], dtype=np.float32))
         else:
@@ -436,7 +437,6 @@
             name=name,
             subset=subset,
             forces=forces,
-            n_atoms=n_atoms,
         )

     def __str__(self):

src/openqdc/datasets/comp6.py

Lines changed: 10 additions & 4 deletions
@@ -37,8 +37,8 @@ class COMP6(BaseDataset):
         "pbe-d3bj/def2-tzvp",
         "pbe/def2-tzvp",
         "svwn/def2-tzvp",
-        "wb97m-d3bj/def2-tzvp",
-        "wb97m/def2-tzvp",
+        # "wb97m-d3bj/def2-tzvp",
+        # "wb97m/def2-tzvp",
     ]

     energy_target_names = [
@@ -49,8 +49,8 @@ class COMP6(BaseDataset):
         "PBE-D3M(BJ):def2-tzvp",
         "PBE:def2-tzvp",
         "SVWN:def2-tzvp",
-        "WB97M-D3(BJ):def2-tzvp",
-        "WB97M:def2-tzvp",
+        # "WB97M-D3(BJ):def2-tzvp",
+        # "WB97M:def2-tzvp",
     ]

     __force_methods__ = [
@@ -150,6 +150,12 @@ def _stats(self):
         },
     }

+    def __smiles_converter__(self, x):
+        """util function to convert string to smiles: useful if the smiles is
+        encoded in a different format than its display format
+        """
+        return "-".join(x.decode("ascii").split("_")[:-1])
+
     def read_raw_entries(self):
         samples = []
         for subset in ["ani_md", "drugbank", "gdb7_9", "gdb10_13", "s66x8", "tripeptides"]:
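The COMP6 converter differs from ANI1's only in its separator; a sketch with a hypothetical entry:

    # COMP6 names use "_" between fields; the converter drops the trailing
    # index and rejoins the rest with "-"
    x = b"drugbank_1234_0"  # hypothetical entry
    print("-".join(x.decode("ascii").split("_")[:-1]))  # -> "drugbank-1234"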

src/openqdc/datasets/iso_17.py

Lines changed: 6 additions & 0 deletions
@@ -85,6 +85,12 @@ def _stats(self):
         },
     }

+    def __smiles_converter__(self, x):
+        """util function to convert string to smiles: useful if the smiles is
+        encoded in a different format than its display format
+        """
+        return "-".join(x.decode("ascii").split("_")[:-1])
+
     def read_raw_entries(self):
         raw_path = p_join(self.root, "iso_17.h5")
         samples = read_qc_archive_h5(raw_path, "iso_17", self.energy_target_names, self.force_target_names)

src/openqdc/datasets/nabladft.py

Lines changed: 19 additions & 7 deletions
@@ -4,38 +4,46 @@

 import datamol as dm
 import numpy as np
-from tqdm import tqdm
+import pandas as pd

 from openqdc.datasets.base import BaseDataset
+from openqdc.utils.molecule import z_to_formula
 from openqdc.utils.package_utils import requires_package


-def to_mol(entry) -> Dict[str, np.ndarray]:
+def to_mol(entry, metadata) -> Dict[str, np.ndarray]:
     Z, R, E, F = entry[:4]
     C = np.zeros_like(Z)
+    E[0] = metadata["DFT TOTAL ENERGY"]

     res = dict(
         atomic_inputs=np.concatenate((Z[:, None], C[:, None], R), axis=-1).astype(np.float32),
-        name=np.array([""]),
+        name=np.array([metadata["SMILES"]]),
         energies=E[:, None].astype(np.float32),
         forces=F[:, :, None].astype(np.float32),
         n_atoms=np.array([Z.shape[0]], dtype=np.int32),
-        subset=np.array(["nabla"]),
+        subset=np.array([z_to_formula(Z)]),
     )

     return res


 @requires_package("nablaDFT")
-def read_chunk_from_db(raw_path, start_idx, stop_idx, step_size=1000):
+def read_chunk_from_db(raw_path, start_idx, stop_idx, labels, step_size=1000):
     from nablaDFT.dataset import HamiltonianDatabase

     print(f"Loading from {start_idx} to {stop_idx}")
     db = HamiltonianDatabase(raw_path)
     idxs = list(np.arange(start_idx, stop_idx))
     n, s = len(idxs), step_size

-    samples = [to_mol(entry) for i in tqdm(range(0, n, s)) for entry in db[idxs[i : i + s]]]
+    cursor = db._get_connection().cursor()
+    data_idxs = cursor.execute("""SELECT * FROM dataset_ids WHERE id IN (""" + str(idxs)[1:-1] + ")").fetchall()
+    c_idxs = [tuple(x[1:]) for x in data_idxs]
+
+    samples = [
+        to_mol(entry, labels[c_idxs[i + j]]) for i in range(0, n, s) for j, entry in enumerate(db[idxs[i : i + s]])
+    ]
     return samples


@@ -68,12 +76,16 @@ class NablaDFT(BaseDataset):
     def read_raw_entries(self):
         from nablaDFT.dataset import HamiltonianDatabase

+        label_path = p_join(self.root, "summary.csv")
+        df = pd.read_csv(label_path, usecols=["MOSES id", "CONFORMER id", "SMILES", "DFT TOTAL ENERGY"])
+        labels = df.set_index(keys=["MOSES id", "CONFORMER id"]).to_dict("index")
+
         raw_path = p_join(self.root, "dataset_full.db")
         train = HamiltonianDatabase(raw_path)
         n, c = len(train), 20
         step_size = int(np.ceil(n / os.cpu_count()))

-        fn = lambda i: read_chunk_from_db(raw_path, i * step_size, min((i + 1) * step_size, n))
+        fn = lambda i: read_chunk_from_db(raw_path, i * step_size, min((i + 1) * step_size, n), labels=labels)
         samples = dm.parallelized(
             fn, list(range(c)), n_jobs=c, progress=False, scheduler="threads"
         )  # don't use more than 1 job