Skip to content

Commit e969b54

Browse files
committed
update L7 and X40 to use python base yaml package
1 parent 802b70b commit e969b54

File tree

2 files changed

+58
-19
lines changed

2 files changed

+58
-19
lines changed

openqdc/datasets/interaction/L7.py

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,53 @@
22
from typing import Dict, List
33

44
import numpy as np
5+
import yaml
56
from loguru import logger
6-
from ruamel.yaml import YAML
77

88
from openqdc.datasets.interaction import BaseInteractionDataset
99
from openqdc.utils.molecule import atom_table
1010

1111

12+
class DataItemYAMLObj:
13+
def __init__(self, name, shortname, geometry, reference_value, setup, group, tags):
14+
self.name = name
15+
self.shortname = shortname
16+
self.geometry = geometry
17+
self.reference_value = reference_value
18+
self.setup = setup
19+
self.group = group
20+
self.tags = tags
21+
22+
23+
class DataSetYAMLObj:
24+
def __init__(self, name, references, text, method_energy, groups_by, groups, global_setup):
25+
self.name = name
26+
self.references = references
27+
self.text = text
28+
self.method_energy = method_energy
29+
self.groups_by = groups_by
30+
self.groups = groups
31+
self.global_setup = global_setup
32+
33+
34+
def data_item_constructor(loader: yaml.SafeLoader, node: yaml.nodes.MappingNode):
35+
"""Construct an employee."""
36+
return DataItemYAMLObj(**loader.construct_mapping(node))
37+
38+
39+
def dataset_constructor(loader: yaml.SafeLoader, node: yaml.nodes.MappingNode):
40+
"""Construct an employee."""
41+
return DataSetYAMLObj(**loader.construct_mapping(node))
42+
43+
44+
def get_loader():
45+
"""Add constructors to PyYAML loader."""
46+
loader = yaml.SafeLoader
47+
loader.add_constructor("!ruby/object:ProtocolDataset::DataSetItem", data_item_constructor)
48+
loader.add_constructor("!ruby/object:ProtocolDataset::DataSetDescription", dataset_constructor)
49+
return loader
50+
51+
1252
class L7(BaseInteractionDataset):
1353
"""
1454
The L7 interaction energy dataset as described in:
@@ -43,23 +83,22 @@ def read_raw_entries(self) -> List[Dict]:
4383
yaml_fpath = os.path.join(self.root, "l7.yaml")
4484
logger.info(f"Reading L7 interaction data from {self.root}")
4585
yaml_file = open(yaml_fpath, "r")
46-
yaml = YAML()
4786
data = []
48-
data_dict = yaml.load(yaml_file)
49-
charge0 = int(data_dict["description"]["global_setup"]["molecule_a"]["charge"])
50-
charge1 = int(data_dict["description"]["global_setup"]["molecule_b"]["charge"])
87+
data_dict = yaml.load(yaml_file, Loader=get_loader())
88+
charge0 = int(data_dict["description"].global_setup["molecule_a"]["charge"])
89+
charge1 = int(data_dict["description"].global_setup["molecule_b"]["charge"])
5190

5291
for idx, item in enumerate(data_dict["items"]):
5392
energies = []
54-
name = np.array([item["shortname"]])
55-
fname = item["geometry"].split(":")[1]
56-
energies.append(item["reference_value"])
93+
name = np.array([item.shortname])
94+
fname = item.geometry.split(":")[1]
95+
energies.append(item.reference_value)
5796
xyz_file = open(os.path.join(self.root, f"{fname}.xyz"), "r")
5897
lines = list(map(lambda x: x.strip().split(), xyz_file.readlines()))
5998
lines.pop(1)
6099
n_atoms = np.array([int(lines[0][0])], dtype=np.int32)
61-
n_atoms_first = np.array([int(item["setup"]["molecule_a"]["selection"].split("-")[1])], dtype=np.int32)
62-
subset = np.array([item["group"]])
100+
n_atoms_first = np.array([int(item.setup["molecule_a"]["selection"].split("-")[1])], dtype=np.int32)
101+
subset = np.array([item.group])
63102
energies += [float(val[idx]) for val in list(data_dict["alternative_reference"].values())]
64103
energies = np.array([energies], dtype=np.float32)
65104
pos = np.array(lines[1:])[:, 1:].astype(np.float32)

openqdc/datasets/interaction/X40.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
from typing import Dict, List
33

44
import numpy as np
5+
import yaml
56
from loguru import logger
6-
from ruamel.yaml import YAML
77

88
from openqdc.datasets.interaction import BaseInteractionDataset
9+
from openqdc.datasets.interaction.L7 import get_loader
910
from openqdc.utils.molecule import atom_table
1011

1112

@@ -41,23 +42,22 @@ def read_raw_entries(self) -> List[Dict]:
4142
yaml_fpath = os.path.join(self.root, "x40.yaml")
4243
logger.info(f"Reading X40 interaction data from {self.root}")
4344
yaml_file = open(yaml_fpath, "r")
44-
yaml = YAML()
4545
data = []
46-
data_dict = yaml.load(yaml_file)
47-
charge0 = int(data_dict["description"]["global_setup"]["molecule_a"]["charge"])
48-
charge1 = int(data_dict["description"]["global_setup"]["molecule_b"]["charge"])
46+
data_dict = yaml.load(yaml_file, Loader=get_loader())
47+
charge0 = int(data_dict["description"].global_setup["molecule_a"]["charge"])
48+
charge1 = int(data_dict["description"].global_setup["molecule_b"]["charge"])
4949

5050
for idx, item in enumerate(data_dict["items"]):
5151
energies = []
52-
name = np.array([item["shortname"]])
53-
energies.append(float(item["reference_value"]))
54-
xyz_file = open(os.path.join(self.root, f"{item['shortname']}.xyz"), "r")
52+
name = np.array([item.shortname])
53+
energies.append(float(item.reference_value))
54+
xyz_file = open(os.path.join(self.root, f"{item.shortname}.xyz"), "r")
5555
lines = list(map(lambda x: x.strip().split(), xyz_file.readlines()))
5656
setup = lines.pop(1)
5757
n_atoms = np.array([int(lines[0][0])], dtype=np.int32)
5858
n_atoms_first = setup[0].split("-")[1]
5959
n_atoms_first = np.array([int(n_atoms_first)], dtype=np.int32)
60-
subset = np.array([item["group"]])
60+
subset = np.array([item.group])
6161
energies += [float(val[idx]) for val in list(data_dict["alternative_reference"].values())]
6262
energies = np.array([energies], dtype=np.float32)
6363
pos = np.array(lines[1:])[:, 1:].astype(np.float32)

0 commit comments

Comments
 (0)