|
2 | 2 | from typing import Dict, List
|
3 | 3 |
|
4 | 4 | import numpy as np
|
| 5 | +import yaml |
5 | 6 | from loguru import logger
|
6 |
| -from ruamel.yaml import YAML |
7 | 7 |
|
8 | 8 | from openqdc.datasets.interaction import BaseInteractionDataset
|
9 | 9 | from openqdc.utils.molecule import atom_table
|
10 | 10 |
|
11 | 11 |
|
| 12 | +class DataItemYAMLObj: |
| 13 | + def __init__(self, name, shortname, geometry, reference_value, setup, group, tags): |
| 14 | + self.name = name |
| 15 | + self.shortname = shortname |
| 16 | + self.geometry = geometry |
| 17 | + self.reference_value = reference_value |
| 18 | + self.setup = setup |
| 19 | + self.group = group |
| 20 | + self.tags = tags |
| 21 | + |
| 22 | + |
| 23 | +class DataSetYAMLObj: |
| 24 | + def __init__(self, name, references, text, method_energy, groups_by, groups, global_setup): |
| 25 | + self.name = name |
| 26 | + self.references = references |
| 27 | + self.text = text |
| 28 | + self.method_energy = method_energy |
| 29 | + self.groups_by = groups_by |
| 30 | + self.groups = groups |
| 31 | + self.global_setup = global_setup |
| 32 | + |
| 33 | + |
| 34 | +def data_item_constructor(loader: yaml.SafeLoader, node: yaml.nodes.MappingNode): |
| 35 | + """Construct an employee.""" |
| 36 | + return DataItemYAMLObj(**loader.construct_mapping(node)) |
| 37 | + |
| 38 | + |
| 39 | +def dataset_constructor(loader: yaml.SafeLoader, node: yaml.nodes.MappingNode): |
| 40 | + """Construct an employee.""" |
| 41 | + return DataSetYAMLObj(**loader.construct_mapping(node)) |
| 42 | + |
| 43 | + |
| 44 | +def get_loader(): |
| 45 | + """Add constructors to PyYAML loader.""" |
| 46 | + loader = yaml.SafeLoader |
| 47 | + loader.add_constructor("!ruby/object:ProtocolDataset::DataSetItem", data_item_constructor) |
| 48 | + loader.add_constructor("!ruby/object:ProtocolDataset::DataSetDescription", dataset_constructor) |
| 49 | + return loader |
| 50 | + |
| 51 | + |
12 | 52 | class L7(BaseInteractionDataset):
|
13 | 53 | """
|
14 | 54 | The L7 interaction energy dataset as described in:
|
@@ -43,23 +83,22 @@ def read_raw_entries(self) -> List[Dict]:
|
43 | 83 | yaml_fpath = os.path.join(self.root, "l7.yaml")
|
44 | 84 | logger.info(f"Reading L7 interaction data from {self.root}")
|
45 | 85 | yaml_file = open(yaml_fpath, "r")
|
46 |
| - yaml = YAML() |
47 | 86 | data = []
|
48 |
| - data_dict = yaml.load(yaml_file) |
49 |
| - charge0 = int(data_dict["description"]["global_setup"]["molecule_a"]["charge"]) |
50 |
| - charge1 = int(data_dict["description"]["global_setup"]["molecule_b"]["charge"]) |
| 87 | + data_dict = yaml.load(yaml_file, Loader=get_loader()) |
| 88 | + charge0 = int(data_dict["description"].global_setup["molecule_a"]["charge"]) |
| 89 | + charge1 = int(data_dict["description"].global_setup["molecule_b"]["charge"]) |
51 | 90 |
|
52 | 91 | for idx, item in enumerate(data_dict["items"]):
|
53 | 92 | energies = []
|
54 |
| - name = np.array([item["shortname"]]) |
55 |
| - fname = item["geometry"].split(":")[1] |
56 |
| - energies.append(item["reference_value"]) |
| 93 | + name = np.array([item.shortname]) |
| 94 | + fname = item.geometry.split(":")[1] |
| 95 | + energies.append(item.reference_value) |
57 | 96 | xyz_file = open(os.path.join(self.root, f"{fname}.xyz"), "r")
|
58 | 97 | lines = list(map(lambda x: x.strip().split(), xyz_file.readlines()))
|
59 | 98 | lines.pop(1)
|
60 | 99 | n_atoms = np.array([int(lines[0][0])], dtype=np.int32)
|
61 |
| - n_atoms_first = np.array([int(item["setup"]["molecule_a"]["selection"].split("-")[1])], dtype=np.int32) |
62 |
| - subset = np.array([item["group"]]) |
| 100 | + n_atoms_first = np.array([int(item.setup["molecule_a"]["selection"].split("-")[1])], dtype=np.int32) |
| 101 | + subset = np.array([item.group]) |
63 | 102 | energies += [float(val[idx]) for val in list(data_dict["alternative_reference"].values())]
|
64 | 103 | energies = np.array([energies], dtype=np.float32)
|
65 | 104 | pos = np.array(lines[1:])[:, 1:].astype(np.float32)
|
|
0 commit comments