1
1
import os
2
+ import pickle as pkl
2
3
from os .path import join as p_join
3
4
from typing import Dict , List , Optional , Union
4
5
26
27
dict_to_atoms ,
27
28
get_local_cache ,
28
29
load_hdf5_file ,
29
- load_pkl ,
30
30
pull_locally ,
31
31
push_remote ,
32
32
set_cache_dir ,
33
33
)
34
- from openqdc .utils .molecule import atom_table
34
+ from openqdc .utils .molecule import atom_table , z_to_formula
35
35
from openqdc .utils .package_utils import requires_package
36
36
from openqdc .utils .units import get_conversion
37
37
@@ -50,7 +50,7 @@ def extract_entry(
50
50
51
51
res = dict (
52
52
name = np .array ([df ["name" ][i ]]),
53
- subset = np .array ([subset ]),
53
+ subset = np .array ([subset if subset is not None else z_to_formula ( x ) ]),
54
54
energies = energies .reshape ((1 , - 1 )).astype (np .float32 ),
55
55
atomic_inputs = np .concatenate ((xs , positions ), axis = - 1 , dtype = np .float32 ),
56
56
n_atoms = np .array ([x .shape [0 ]], dtype = np .int32 ),
@@ -71,8 +71,8 @@ def read_qc_archive_h5(
71
71
) -> List [Dict [str , np .ndarray ]]:
72
72
data = load_hdf5_file (raw_path )
73
73
data_t = {k2 : data [k1 ][k2 ][:] for k1 in data .keys () for k2 in data [k1 ].keys ()}
74
- n = len (data_t ["molecule_id" ])
75
74
75
+ n = len (data_t ["molecule_id" ])
76
76
samples = [extract_entry (data_t , i , subset , energy_target_names , force_target_names ) for i in tqdm (range (n ))]
77
77
return samples
78
78
@@ -103,7 +103,7 @@ def __init__(
103
103
self .data = None
104
104
self ._set_units (energy_unit , distance_unit )
105
105
if not self .is_preprocessed ():
106
- self ._download ( )
106
+ raise DatasetNotAvailableError ( self .__name__ )
107
107
else :
108
108
self .read_preprocess (overwrite_local_cache = overwrite_local_cache )
109
109
self ._set_isolated_atom_energies ()
@@ -120,12 +120,12 @@ def _download(self):
120
120
def numbers(self):
    """Return the distinct values stored in column 0 of ``atomic_inputs``.

    The values are collected with ``np.unique`` (so the result is sorted and
    free of duplicates), cast to ``np.int32`` and cached on the instance as
    ``_numbers``; every later call returns that cached array unchanged.

    NOTE(review): column 0 presumably holds atomic numbers -- confirm against
    the code that assembles ``atomic_inputs``.
    """
    if not hasattr(self, "_numbers"):
        first_column = self.data["atomic_inputs"][..., 0]
        # np.unique gives a deterministic, ascending order, unlike iterating a set.
        self._numbers = np.unique(first_column).astype(np.int32)
    return self._numbers
125
125
126
126
@property
def chemical_species(self):
    """Look up ``chemical_symbols`` at every index listed in ``self.numbers``.

    Returns:
        A NumPy array obtained by indexing ``np.array(chemical_symbols)``
        with ``self.numbers``; one entry per value in ``self.numbers``.
    """
    symbol_table = np.array(chemical_symbols)
    return symbol_table[self.numbers]
129
129
130
130
@property
131
131
def energy_unit (self ):
@@ -224,10 +224,11 @@ def collate_list(self, list_entries):
224
224
# concatenate entries
225
225
res = {key : np .concatenate ([r [key ] for r in list_entries if r is not None ], axis = 0 ) for key in list_entries [0 ]}
226
226
227
- csum = np .cumsum (res .pop ("n_atoms" ))
227
+ csum = np .cumsum (res .get ("n_atoms" ))
228
228
x = np .zeros ((csum .shape [0 ], 2 ), dtype = np .int32 )
229
229
x [1 :, 0 ], x [:, 1 ] = csum [:- 1 ], csum
230
230
res ["position_idx_range" ] = x
231
+
231
232
return res
232
233
233
234
def save_preprocess (self , data_dict ):
@@ -241,12 +242,13 @@ def save_preprocess(self, data_dict):
241
242
push_remote (local_path , overwrite = True )
242
243
243
244
# save smiles and subset
245
+ local_path = p_join (self .preprocess_path , "props.pkl" )
244
246
for key in ["name" , "subset" ]:
245
- local_path = p_join ( self . preprocess_path , f" { key } .npz" )
246
- uniques , inv_indices = np . unique ( data_dict [ key ], return_inverse = True )
247
- with open (local_path , "wb" ) as f :
248
- np . savez_compressed ( f , uniques = uniques , inv_indices = inv_indices )
249
- push_remote (local_path )
247
+ data_dict [ key ] = np . unique ( data_dict [ key ], return_inverse = True )
248
+
249
+ with open (local_path , "wb" ) as f :
250
+ pkl . dump ( data_dict , f )
251
+ push_remote (local_path , overwrite = True )
250
252
251
253
def read_preprocess (self , overwrite_local_cache = False ):
252
254
logger .info ("Reading preprocessed data" )
@@ -260,36 +262,29 @@ def read_preprocess(self, overwrite_local_cache=False):
260
262
for key in self .data_keys :
261
263
filename = p_join (self .preprocess_path , f"{ key } .mmap" )
262
264
pull_locally (filename , overwrite = overwrite_local_cache )
263
- self .data [key ] = np .memmap (
264
- filename ,
265
- mode = "r" ,
266
- dtype = self .data_types [key ],
267
- ).reshape (self .data_shapes [key ])
265
+ self .data [key ] = np .memmap (filename , mode = "r" , dtype = self .data_types [key ]).reshape (self .data_shapes [key ])
266
+
267
+ filename = p_join (self .preprocess_path , "props.pkl" )
268
+ pull_locally (filename , overwrite = overwrite_local_cache )
269
+ with open (filename , "rb" ) as f :
270
+ tmp = pkl .load (f )
271
+ for key in ["name" , "subset" , "n_atoms" ]:
272
+ x = tmp .pop (key )
273
+ if len (x ) == 2 :
274
+ self .data [key ] = x [0 ][x [1 ]]
275
+ else :
276
+ self .data [key ] = x
268
277
269
278
for key in self .data :
270
279
logger .info (f"Loaded { key } with shape { self .data [key ].shape } , dtype { self .data [key ].dtype } " )
271
280
272
- for key in ["props" ]:
273
- filename = p_join (self .preprocess_path , f"{ key } .pkl" )
274
- pull_locally (filename )
275
- for key , v in load_pkl (filename ).items ():
276
- self .data [key ] = dict ()
277
- if key == "n_atoms" :
278
- self .data [key ] = v
279
- logger .info (f"Loaded { key } with shape { self .data [key ].shape } , dtype { self .data [key ].dtype } " )
280
- else :
281
- self .data [key ]["uniques" ] = v [0 ]
282
- self .data [key ]["inv_indices" ] = v [1 ]
283
- logger .info (f"Loaded { key } _{ 'uniques' } with shape { v [0 ].shape } , dtype { v [0 ].dtype } " )
284
- logger .info (f"Loaded { key } _{ 'inv_indices' } with shape { v [1 ].shape } , dtype { v [1 ].dtype } " )
285
-
286
281
def is_preprocessed(self):
    """Tell whether every preprocessed artefact is already available.

    ``copy_exists`` is queried for one ``<key>.mmap`` file per entry of
    ``self.data_keys`` and for ``props.pkl``, all located under
    ``self.preprocess_path``.

    Returns:
        bool: True only if every one of those checks succeeds.
    """
    expected_files = [f"{key}.mmap" for key in self.data_keys]
    # Must match the name used when the pickle is written and read back.
    expected_files.append("props.pkl")
    # Build the full list first so that every file is checked, as before.
    return all([copy_exists(p_join(self.preprocess_path, name)) for name in expected_files])
290
285
291
- def preprocess (self ):
292
- if not self .is_preprocessed ():
286
+ def preprocess (self , overwrite = False ):
287
+ if overwrite or not self .is_preprocessed ():
293
288
entries = self .read_raw_entries ()
294
289
res = self .collate_list (entries )
295
290
self .save_preprocess (res )
@@ -323,7 +318,7 @@ def get_ase_atoms(self, idx: int, ext=True):
323
318
324
319
@requires_package ("dscribe" )
325
320
@requires_package ("datamol" )
326
- def chemical_space (
321
+ def soap_descriptors (
327
322
self ,
328
323
n_samples : Optional [Union [List [int ], int ]] = None ,
329
324
return_idxs : bool = True ,
@@ -368,7 +363,7 @@ def chemical_space(
368
363
idxs = list (range (len (self )))
369
364
elif isinstance (n_samples , int ):
370
365
idxs = np .random .choice (len (self ), size = n_samples , replace = False )
371
- elif isinstance ( n_samples , list ):
366
+ else : # list, set, np.ndarray
372
367
idxs = n_samples
373
368
datum = {}
374
369
r_cut = soap_kwargs .pop ("r_cut" , 5.0 )
@@ -401,7 +396,7 @@ def wrapper(idx):
401
396
entry = self .get_ase_atoms (idx , ext = False )
402
397
return soap .create (entry , centers = entry .positions )
403
398
404
- descr = dm .parallelized (wrapper , idxs , progress = progress , scheduler = "threads" )
399
+ descr = dm .parallelized (wrapper , idxs , progress = progress , scheduler = "threads" , n_jobs = - 1 )
405
400
datum ["soap" ] = np .vstack (descr )
406
401
if return_idxs :
407
402
datum ["idxs" ] = idxs
@@ -410,6 +405,12 @@ def wrapper(idx):
410
405
def __len__(self):
    """Return the size of the first axis of ``self.data["energies"]``."""
    energies = self.data["energies"]
    return energies.shape[0]
412
407
408
def __smiles_converter__(self, x):
    """Convert a stored string into its SMILES form.

    This implementation hands ``x`` back unchanged. It exists for the case
    where the SMILES is encoded in a format that differs from the one used
    for display (presumably by overriding this method -- confirm with the
    subclasses).
    """
    return x
413
+
413
414
def __getitem__ (self , idx : int ):
414
415
shift = IsolatedAtomEnergyFactory .max_charge
415
416
p_start , p_end = self .data ["position_idx_range" ][idx ]
@@ -420,9 +421,9 @@ def __getitem__(self, idx: int):
420
421
self .convert_distance (np .array (input [:, - 3 :], dtype = np .float32 )),
421
422
self .convert_energy (np .array (self .data ["energies" ][idx ], dtype = np .float32 )),
422
423
)
423
- name = self .data [ "name" ][ "uniques" ][ self .data ["name" ]["inv_indices" ][ idx ]]
424
- subset = self .data ["subset" ]["uniques" ][ self . data [ "subset" ][ "inv_indices" ][ idx ] ]
425
- n_atoms = self . data [ "n_atoms" ][ idx ]
424
+ name = self .__smiles_converter__ ( self .data ["name" ][idx ])
425
+ subset = self .data ["subset" ][idx ]
426
+
426
427
if "forces" in self .data :
427
428
forces = self .convert_forces (np .array (self .data ["forces" ][p_start :p_end ], dtype = np .float32 ))
428
429
else :
@@ -436,7 +437,6 @@ def __getitem__(self, idx: int):
436
437
name = name ,
437
438
subset = subset ,
438
439
forces = forces ,
439
- n_atoms = n_atoms ,
440
440
)
441
441
442
442
def __str__ (self ):
0 commit comments