Skip to content

Commit 18bc79c

Browse files
author
Nikhil Shenoy
committed
bug fix and simplifying interaction dataset
1 parent ed8e264 commit 18bc79c

File tree

2 files changed

+3
-18
lines changed

2 files changed

+3
-18
lines changed

openqdc/datasets/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,8 +341,8 @@ def save_preprocess(self, data_dict):
341341
# save smiles and subset
342342
local_path = p_join(self.preprocess_path, "props.pkl")
343343

344-
# assert that required keys are present in data_dict
345-
assert all([key in self.pkl_data_keys for key in data_dict.keys()])
344+
# assert that (required) pkl keys are present in data_dict
345+
assert all([key in data_dict.keys() for key in self.pkl_data_keys])
346346

347347
# store unique and inverse indices for str-based pkl keys
348348
for key in self.pkl_data_keys:

openqdc/datasets/interaction/base.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22
from os.path import join as p_join
3-
from typing import Dict, List, Optional
3+
from typing import Optional
44

55
import numpy as np
66
from ase.io.extxyz import write_extxyz
@@ -23,21 +23,6 @@ def pkl_data_types(self):
2323
"n_atoms_first": np.int32,
2424
}
2525

26-
def collate_list(self, list_entries: List[Dict]):
27-
# concatenate entries
28-
res = {
29-
key: np.concatenate([r[key] for r in list_entries if r is not None], axis=0)
30-
for key in list_entries[0]
31-
if not isinstance(list_entries[0][key], dict)
32-
}
33-
34-
csum = np.cumsum(res.get("n_atoms"))
35-
x = np.zeros((csum.shape[0], 2), dtype=np.int32)
36-
x[1:, 0], x[:, 1] = csum[:-1], csum
37-
res["position_idx_range"] = x
38-
39-
return res
40-
4126
def __getitem__(self, idx: int):
4227
shift = MAX_CHARGE
4328
p_start, p_end = self.data["position_idx_range"][idx]

0 commit comments

Comments
 (0)