diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1f04d36..9194d2c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -16,7 +16,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.9", "3.10"] + python-version: ["3.9", "3.10", "3.11", "3.12"] os: ["ubuntu-latest"] runs-on: ${{ matrix.os }} diff --git a/openqdc/__init__.py b/openqdc/__init__.py index 35d6227..7e2eb2c 100644 --- a/openqdc/__init__.py +++ b/openqdc/__init__.py @@ -34,6 +34,8 @@ def get_project_root(): "SN2RXN": "openqdc.datasets.potential.sn2_rxn", "QM7X": "openqdc.datasets.potential.qm7x", "QM7X_V2": "openqdc.datasets.potential.qm7x", + "QM1B": "openqdc.datasets.potential.qm1b", + "QM1B_SMALL": "openqdc.datasets.potential.qm1b", "NablaDFT": "openqdc.datasets.potential.nabladft", "SolvatedPeptides": "openqdc.datasets.potential.solvated_peptides", "WaterClusters": "openqdc.datasets.potential.waterclusters3_30", @@ -101,6 +103,8 @@ def __dir__(): from .datasets.interaction.metcalf import Metcalf from .datasets.interaction.splinter import Splinter from .datasets.interaction.x40 import X40 + + # POTENTIAL from .datasets.potential.ani import ANI1, ANI1CCX, ANI1CCX_V2, ANI1X, ANI2X from .datasets.potential.comp6 import COMP6 from .datasets.potential.dummy import Dummy @@ -113,6 +117,7 @@ def __dir__(): from .datasets.potential.nabladft import NablaDFT from .datasets.potential.orbnet_denali import OrbnetDenali from .datasets.potential.pcqm import PCQM_B3LYP, PCQM_PM6 + from .datasets.potential.qm1b import QM1B, QM1B_SMALL from .datasets.potential.qm7x import QM7X, QM7X_V2 from .datasets.potential.qmugs import QMugs, QMugs_V2 from .datasets.potential.revmd17 import RevMD17 diff --git a/openqdc/datasets/base.py b/openqdc/datasets/base.py index db04fcd..026bfd7 100644 --- a/openqdc/datasets/base.py +++ b/openqdc/datasets/base.py @@ -38,7 +38,12 @@ ) from openqdc.utils.package_utils import has_package, requires_package from openqdc.utils.regressor 
import Regressor # noqa -from openqdc.utils.units import get_conversion +from openqdc.utils.units import ( + DistanceTypeConversion, + EnergyTypeConversion, + ForceTypeConversion, + get_conversion, +) if has_package("torch"): import torch @@ -129,7 +134,7 @@ def __init__( set_cache_dir(cache_dir) # self._init_lambda_fn() self.data = None - self._original_unit = self.__energy_unit__ + self._original_unit = self.energy_unit self.recompute_statistics = recompute_statistics self.regressor_kwargs = regressor_kwargs self.transform = transform @@ -225,24 +230,27 @@ def e0s_dispatcher(self): def _convert_data(self): logger.info( f"Converting {self.__name__} data to the following units:\n\ - Energy: {self.energy_unit},\n\ - Distance: {self.distance_unit},\n\ - Forces: {self.force_unit if self.__force_methods__ else 'None'}" + Energy: {str(self.energy_unit)},\n\ + Distance: {str(self.distance_unit)},\n\ + Forces: {str(self.force_unit) if self.__force_methods__ else 'None'}" ) for key in self.data_keys: self.data[key] = self._convert_on_loading(self.data[key], key) @property def energy_unit(self): - return self.__energy_unit__ + return EnergyTypeConversion(self.__energy_unit__) @property def distance_unit(self): - return self.__distance_unit__ + return DistanceTypeConversion(self.__distance_unit__) @property def force_unit(self): - return self.__forces_unit__ + units = self.__forces_unit__.split("/") + if len(units) > 2: + units = ["/".join(units[:2]), units[-1]] + return ForceTypeConversion(tuple(units)) # < 3.12 compatibility @property def root(self): @@ -291,15 +299,15 @@ def data_shapes(self): "forces": (-1, 3, len(self.force_methods)), } - def _set_units(self, en, ds): + def _set_units(self, en: Optional[str] = None, ds: Optional[str] = None): old_en, old_ds = self.energy_unit, self.distance_unit en = en if en is not None else old_en ds = ds if ds is not None else old_ds self.set_energy_unit(en) self.set_distance_unit(ds) if self.__force_methods__: - self.__forces_unit__ 
= self.energy_unit + "/" + self.distance_unit - self._fn_forces = get_conversion(old_en + "/" + old_ds, self.__forces_unit__) + self._fn_forces = self.force_unit.to(str(self.energy_unit), str(self.distance_unit)) + self.__forces_unit__ = str(self.energy_unit) + "/" + str(self.distance_unit) def _set_isolated_atom_energies(self): if self.__energy_methods__ is None: @@ -308,7 +316,7 @@ def _set_isolated_atom_energies(self): f = get_conversion("hartree", self.__energy_unit__) else: # regression are calculated on the original unit of the dataset - f = get_conversion(self._original_unit, self.__energy_unit__) + f = self._original_unit.to(self.energy_unit) self.__isolated_atom_energies__ = f(self.e0s_dispatcher.e0s_matrix) def convert_energy(self, x): @@ -324,17 +332,19 @@ def set_energy_unit(self, value: str): """ Set a new energy unit for the dataset. """ - old_unit = self.energy_unit + # old_unit = self.energy_unit + # self.__energy_unit__ = value + self._fn_energy = self.energy_unit.to(value) # get_conversion(old_unit, value) self.__energy_unit__ = value - self._fn_energy = get_conversion(old_unit, value) def set_distance_unit(self, value: str): """ Set a new distance unit for the dataset. """ - old_unit = self.distance_unit + # old_unit = self.distance_unit + # self.__distance_unit__ = value + self._fn_distance = self.distance_unit.to(value) # get_conversion(old_unit, value) self.__distance_unit__ = value - self._fn_distance = get_conversion(old_unit, value) def set_array_format(self, format: str): assert format in ["numpy", "torch", "jax"], f"Format {format} not supported." 
diff --git a/openqdc/datasets/io.py b/openqdc/datasets/io.py index cd8bfdb..1e621f7 100644 --- a/openqdc/datasets/io.py +++ b/openqdc/datasets/io.py @@ -47,8 +47,8 @@ def __init__( self.recompute_statistics = True self.refit_e0s = True self.energy_type = energy_type - self._original_unit = energy_unit self.__energy_unit__ = energy_unit + self._original_unit = self.energy_unit self.__distance_unit__ = distance_unit self.__energy_methods__ = [PotentialMethod.NONE if not level_of_theory else level_of_theory] self.energy_target_names = ["xyz"] diff --git a/openqdc/datasets/potential/__init__.py b/openqdc/datasets/potential/__init__.py index db826a1..b59fcad 100644 --- a/openqdc/datasets/potential/__init__.py +++ b/openqdc/datasets/potential/__init__.py @@ -10,6 +10,7 @@ from .nabladft import NablaDFT from .orbnet_denali import OrbnetDenali from .pcqm import PCQM_B3LYP, PCQM_PM6 +from .qm1b import QM1B, QM1B_SMALL from .qm7x import QM7X, QM7X_V2 from .qmugs import QMugs, QMugs_V2 from .revmd17 import RevMD17 @@ -39,6 +40,8 @@ "QM7X_V2": QM7X_V2, "QMugs": QMugs, "QMugs_V2": QMugs_V2, + "QM1B": QM1B, + "QM1B_SMALL": QM1B_SMALL, "SN2RXN": SN2RXN, "SolvatedPeptides": SolvatedPeptides, "Spice": Spice, diff --git a/openqdc/datasets/potential/ani.py b/openqdc/datasets/potential/ani.py index b31d658..bcff384 100644 --- a/openqdc/datasets/potential/ani.py +++ b/openqdc/datasets/potential/ani.py @@ -210,7 +210,21 @@ class ANI1CCX_V2(ANI1CCX): class ANI2X(ANI1): - """ """ + """ + The ANI-2X dataset was constructed using active learning from modified versions of GDB-11, CheMBL, + and s66x8. It adds three new elements (F, Cl, S) resulting in 4.6 million conformers from 13k + chemical isomers, optimized using the LBFGS algorithm and labeled with ωB97X/6-31G*. 
+ + Usage ```python + from openqdc.datasets import ANI2X + dataset = ANI2X() + ``` + + References: + - ANI-2x: https://doi.org/10.1021/acs.jctc.0c00121 + - Github: https://github.com/aiqm/ANI1x_datasets + """ __name__ = "ani2x" __energy_unit__ = "hartree" diff --git a/openqdc/datasets/potential/qm1b.py b/openqdc/datasets/potential/qm1b.py new file mode 100644 index 0000000..5e10ed2 --- /dev/null +++ b/openqdc/datasets/potential/qm1b.py @@ -0,0 +1,157 @@ +import os +from functools import partial +from os.path import join as p_join + +import datamol as dm +import numpy as np +import pandas as pd + +from openqdc.datasets.base import BaseDataset +from openqdc.methods import PotentialMethod +from openqdc.utils.io import get_local_cache + +# fmt: off +FILE_NUM = [ + "43005175","43005205","43005208","43005211","43005214","43005223", + "43005235","43005241","43005244","43005247","43005253","43005259", + "43005265","43005268","43005271","43005274","43005277","43005280", + "43005286","43005292","43005298","43005304","43005307","43005313", + "43005319","43005322","43005325","43005331","43005337","43005343", + "43005346","43005349","43005352","43005355","43005358","43005364", + "43005370","43005406","43005409","43005415","43005418","43005421", + "43005424","43005427","43005430","43005433","43005436","43005439", + "43005442","43005448","43005454","43005457","43005460","43005463", + "43005466","43005469","43005472","43005475","43005478","43005481", + "43005484","43005487","43005490","43005496","43005499","43005502", + "43005505","43005508","43005511","43005514","43005517","43005520", + "43005523","43005526","43005532","43005538","43005544","43005547", + "43005550","43005553","43005556","43005559","43005562","43005577", + "43005580","43005583","43005589","43005592","43005598","43005601", + "43005616","43005622","43005625","43005628","43005634","43005637", + "43005646","43005649","43005658","43005661","43005676","43006159", +
"43006162","43006165","43006168","43006171","43006174","43006177", + "43006180","43006186","43006207","43006210","43006213","43006219", + "43006222","43006228","43006231","43006273","43006276","43006279", + "43006282","43006288","43006294","43006303","43006318","43006324", + "43006330","43006333","43006336","43006345","43006354","43006372", + "43006381","43006384","43006390","43006396","43006405","43006408", + "43006411","43006423","43006432","43006456","43006468","43006471", + "43006477","43006486","43006489","43006492","43006498","43006501", + "43006513","43006516","43006522","43006525","43006528","43006531", + "43006534","43006537","43006543","43006546","43006576","43006579", + "43006603","43006609","43006615","43006621","43006624","43006627", + "43006630","43006633","43006639","43006645","43006651","43006654", + "43006660","43006663","43006666","43006669","43006672","43006681", + "43006690","43006696","43006699","43006711","43006717","43006738", + "43006747","43006756","43006762","43006765","43006768","43006771", + "43006777","43006780","43006786","43006789","43006795","43006798", + "43006801","43006804","43006816","43006822","43006837","43006840", + "43006846","43006855","43006858","43006861","43006864","43006867", + "43006870","43006873","43006876","43006882","43006897","43006900", + "43006903","43006909","43006912","43006927","43006930","43006933", + "43006939","43006942","43006948","43006951","43006954","43006957", + "43006966","43006969","43006978","43006984","43006993","43006996", + "43006999","43007002","43007005","43007008","43007011","43007014", + "43007017","43007032","43007035","43007041","43007044","43007047", + "43007050","43007053","43007056","43007062","43007068","43007080", + "43007098","43007110","43007119","43007122","43007125", +] +# fmt: on + + +def extract_from_row(row, file_idx=None): + smiles = row["smile"] + z = np.array(row["z"])[:, None] + c = np.zeros_like(z) + x = np.concatenate((z, c), axis=1) + positions = 
np.array(row["pos"]).reshape(-1, 3) + + res = dict( + name=np.array([smiles]), + subset=np.array(["qm1b"]) if file_idx is None else np.array([f"qm1b_{file_idx}"]), + energies=np.array([row["energy"]]).astype(np.float64)[:, None], + atomic_inputs=np.concatenate((x, positions), axis=-1, dtype=np.float32), + n_atoms=np.array([x.shape[0]], dtype=np.int32), + ) + return res + + +class QM1B(BaseDataset): + """ + QM1B is a low-resolution DFT dataset generated using PySCF IPU. + It is composed of one billion training examples containing 9-11 heavy atoms. + It was created by taking 1.09M SMILES strings from the GDB-11 database and + computing molecular properties (e.g. HOMO-LUMO gap) for a set of up to 1000 + conformers per molecule at the B3LYP/STO-3G level of theory. + + Usage: + ```python + from openqdc.datasets import QM1B + dataset = QM1B() + ``` + + References: + - https://arxiv.org/pdf/2311.01135 + - https://github.com/graphcore-research/qm1b-dataset/ + """ + + __name__ = "qm1b" + + __energy_methods__ = [PotentialMethod.B3LYP_STO3G] + __force_methods__ = [] + + energy_target_names = ["b3lyp/sto-3g"] + force_target_names = [] + + __energy_unit__ = "ev" + __distance_unit__ = "bohr" + __forces_unit__ = "ev/bohr" + __links__ = { + "qm1b_validation.parquet": "https://ndownloader.figshare.com/files/43005175", + **{f"part_{i:03d}.parquet": f"https://ndownloader.figshare.com/files/{FILE_NUM[i]}" for i in range(0, 256)}, + } + + @property + def root(self): + return p_join(get_local_cache(), "qm1b") + + @property + def preprocess_path(self): + path = p_join(self.root, "preprocessed", self.__name__) + os.makedirs(path, exist_ok=True) + return path + + def read_raw_entries(self): + filenames = list(map(lambda x: p_join(self.root, f"part_{x:03d}.parquet"), list(range(0, 256)))) + [ + p_join(self.root, "qm1b_validation.parquet") + ] + + def read_entries_parallel(filename): + df = pd.read_parquet(filename) + + def extract_parallel(df, i): + return extract_from_row(df.iloc[i]) + + 
fn = partial(extract_parallel, df) + list_of_idxs = list(range(len(df))) + results = dm.utils.parallelized(fn, list_of_idxs, scheduler="threads", progress=False) + return results + + list_of_list = dm.utils.parallelized(read_entries_parallel, filenames, scheduler="processes", progress=True) + + return [x for xs in list_of_list for x in xs] + + +class QM1B_SMALL(QM1B): + """ + QM1B_SMALL is a subset of the QM1B dataset containing a + maximum of 15 random conformers per molecule. + + Usage: + ```python + from openqdc.datasets import QM1B_SMALL + dataset = QM1B_SMALL() + ``` + """ + + __name__ = "qm1b_small" diff --git a/openqdc/methods/enums.py b/openqdc/methods/enums.py index 6689f3a..9dff4a1 100644 --- a/openqdc/methods/enums.py +++ b/openqdc/methods/enums.py @@ -9,7 +9,7 @@ class StrEnum(str, Enum): def __str__(self): - return self.value + return self.value.lower() @unique @@ -45,6 +45,7 @@ class BasisSet(StrEnum): HA_DZ = "haDZ" HA_TZ = "haTZ" CBS_ADZ = "cbs(adz)" + STO3G = "sto-3g" GSTAR = "6-31g*" CC_PVDZ = "cc-pvdz" CC_PVTZ = "cc-pvtz" @@ -231,6 +232,7 @@ class PotentialMethod(QmMethod): # SPLIT FOR INTERACTIO ENERGIES AND FIX MD1 B1PW91_VWN5_DZP = Functional.B1PW91_VWN5, BasisSet.DZP B1PW91_VWN5_SZ = Functional.B1PW91_VWN5, BasisSet.SZ B1PW91_VWN5_TZP = Functional.B1PW91_VWN5, BasisSet.TZP + B3LYP_STO3G = Functional.B3LYP, BasisSet.STO3G # TODO: calculate e0s B3LYP_VWN5_DZP = Functional.B3LYP_VWN5, BasisSet.DZP B3LYP_VWN5_SZ = Functional.B3LYP_VWN5, BasisSet.SZ B3LYP_VWN5_TZP = Functional.B3LYP_VWN5, BasisSet.TZP diff --git a/openqdc/utils/units.py b/openqdc/utils/units.py index 12c13f9..d8613a5 100644 --- a/openqdc/utils/units.py +++ b/openqdc/utils/units.py @@ -8,6 +8,7 @@ ["ang", "nm", "bohr"] """ +from enum import Enum, unique from typing import Callable from openqdc.utils.exceptions import ConversionAlreadyDefined, ConversionNotDefinedError @@ -15,6 +16,85 @@ CONVERSION_REGISTRY = {} +# Redefined to avoid circular imports +class StrEnum(str, Enum): + def 
__str__(self): + return self.value.lower() + + +# Parent class for all conversion enums +class ConversionEnum(Enum): + pass + + +@unique +class EnergyTypeConversion(ConversionEnum, StrEnum): + """ + Define the possible energy units for conversion + """ + + KCAL_MOL = "kcal/mol" + KJ_MOL = "kj/mol" + HARTREE = "hartree" + EV = "ev" + MEV = "mev" + RYD = "ryd" + + def to(self, energy: "EnergyTypeConversion"): + return get_conversion(str(self), str(energy)) + + +@unique +class DistanceTypeConversion(ConversionEnum, StrEnum): + """ + Define the possible distance units for conversion + """ + + ANG = "ang" + NM = "nm" + BOHR = "bohr" + + def to(self, distance: "DistanceTypeConversion", fraction: bool = False): + return get_conversion(str(self), str(distance)) if not fraction else get_conversion(str(distance), str(self)) + + +@unique +class ForceTypeConversion(ConversionEnum): + """ + Define the possible force units for conversion + """ + + # Name = EnergyTypeConversion, DistanceTypeConversion + HARTREE_BOHR = EnergyTypeConversion.HARTREE, DistanceTypeConversion.BOHR + HARTREE_ANG = EnergyTypeConversion.HARTREE, DistanceTypeConversion.ANG + HARTREE_NM = EnergyTypeConversion.HARTREE, DistanceTypeConversion.NM + EV_BOHR = EnergyTypeConversion.EV, DistanceTypeConversion.BOHR + EV_ANG = EnergyTypeConversion.EV, DistanceTypeConversion.ANG + EV_NM = EnergyTypeConversion.EV, DistanceTypeConversion.NM + KCAL_MOL_BOHR = EnergyTypeConversion.KCAL_MOL, DistanceTypeConversion.BOHR + KCAL_MOL_ANG = EnergyTypeConversion.KCAL_MOL, DistanceTypeConversion.ANG + KCAL_MOL_NM = EnergyTypeConversion.KCAL_MOL, DistanceTypeConversion.NM + KJ_MOL_BOHR = EnergyTypeConversion.KJ_MOL, DistanceTypeConversion.BOHR + KJ_MOL_ANG = EnergyTypeConversion.KJ_MOL, DistanceTypeConversion.ANG + KJ_MOL_NM = EnergyTypeConversion.KJ_MOL, DistanceTypeConversion.NM + MEV_BOHR = EnergyTypeConversion.MEV, DistanceTypeConversion.BOHR + MEV_ANG = EnergyTypeConversion.MEV, DistanceTypeConversion.ANG + MEV_NM = 
EnergyTypeConversion.MEV, DistanceTypeConversion.NM + RYD_BOHR = EnergyTypeConversion.RYD, DistanceTypeConversion.BOHR + RYD_ANG = EnergyTypeConversion.RYD, DistanceTypeConversion.ANG + RYD_NM = EnergyTypeConversion.RYD, DistanceTypeConversion.NM + + def __init__(self, energy: EnergyTypeConversion, distance: DistanceTypeConversion): + self.energy = energy + self.distance = distance + + def __str__(self): + return f"{self.energy}/{self.distance}" + + def to(self, energy: EnergyTypeConversion, distance: DistanceTypeConversion): + return lambda x: self.distance.to(distance, fraction=True)(self.energy.to(energy)(x)) + + class Conversion: """ Conversion from one unit system to another. @@ -66,23 +146,37 @@ def get_conversion(in_unit: str, out_unit: str): Conversion("ev", "kcal/mol", lambda x: x * 23.0605) Conversion("ev", "hartree", lambda x: x * 0.0367493) Conversion("ev", "kj/mol", lambda x: x * 96.4853) -Conversion("mev", "ev", lambda x: x * 1000.0) -Conversion("ev", "mev", lambda x: x * 0.0001) +Conversion("ev", "mev", lambda x: x * 1000.0) +Conversion("mev", "ev", lambda x: x * 0.001) +Conversion("ev", "ryd", lambda x: x * 0.07349864) # kcal/mol conversion Conversion("kcal/mol", "ev", lambda x: x * 0.0433641) Conversion("kcal/mol", "hartree", lambda x: x * 0.00159362) Conversion("kcal/mol", "kj/mol", lambda x: x * 4.184) +Conversion("kcal/mol", "mev", lambda x: get_conversion("ev", "mev")(get_conversion("kcal/mol", "ev")(x))) +Conversion("kcal/mol", "ryd", lambda x: x * 0.00318720) # hartree conversion Conversion("hartree", "ev", lambda x: x * 27.211386246) Conversion("hartree", "kcal/mol", lambda x: x * 627.509) Conversion("hartree", "kj/mol", lambda x: x * 2625.5) +Conversion("hartree", "mev", lambda x: get_conversion("ev", "mev")(get_conversion("hartree", "ev")(x))) +Conversion("hartree", "ryd", lambda x: x * 2.0) # kj/mol conversion Conversion("kj/mol", "ev", lambda x: x * 0.0103643) Conversion("kj/mol", "kcal/mol", lambda x: x * 0.239006) Conversion("kj/mol", 
"hartree", lambda x: x * 0.000380879) +Conversion("kj/mol", "mev", lambda x: get_conversion("ev", "mev")(get_conversion("kj/mol", "ev")(x))) +Conversion("kj/mol", "ryd", lambda x: x * 0.000761760) + +# Rydberg conversion +Conversion("ryd", "ev", lambda x: x * 13.60569301) +Conversion("ryd", "kcal/mol", lambda x: x * 313.7545) +Conversion("ryd", "hartree", lambda x: x * 0.5) +Conversion("ryd", "kj/mol", lambda x: x * 1312.75) +Conversion("ryd", "mev", lambda x: get_conversion("ev", "mev")(get_conversion("ryd", "ev")(x))) # distance conversions Conversion("bohr", "ang", lambda x: x * 0.52917721092) @@ -91,21 +185,3 @@ def get_conversion(in_unit: str, out_unit: str): Conversion("nm", "ang", lambda x: x * 10.0) Conversion("nm", "bohr", lambda x: x * 18.8973) Conversion("bohr", "nm", lambda x: x / 18.8973) - -# common forces conversion -Conversion("hartree/bohr", "ev/ang", lambda x: get_conversion("ang", "bohr")(get_conversion("hartree", "ev")(x))) -Conversion("hartree/bohr", "ev/bohr", lambda x: get_conversion("hartree", "ev")(x)) -Conversion("hartree/bohr", "kcal/mol/bohr", lambda x: get_conversion("hartree", "kcal/mol")(x)) -Conversion( - "hartree/bohr", "kcal/mol/ang", lambda x: get_conversion("ang", "bohr")(get_conversion("hartree", "kcal/mol")(x)) -) -Conversion("hartree/ang", "kcal/mol/ang", lambda x: get_conversion("hartree", "kcal/mol")(x)) -Conversion("hartree/ang", "hartree/bohr", lambda x: get_conversion("bohr", "ang")(x)) -Conversion("hartree/bohr", "hartree/ang", lambda x: get_conversion("ang", "bohr")(x)) -Conversion("kcal/mol/bohr", "hartree/bohr", lambda x: get_conversion("kcal/mol", "hartree")(x)) -Conversion("ev/ang", "hartree/ang", lambda x: get_conversion("ev", "hartree")(x)) -Conversion("ev/bohr", "hartree/bohr", lambda x: get_conversion("ev", "hartree")(x)) -Conversion("ev/bohr", "ev/ang", lambda x: get_conversion("ang", "bohr")(x)) -Conversion("ev/bohr", "kcal/mol/ang", lambda x: get_conversion("ang", "bohr")(get_conversion("ev", "kcal/mol")(x))) 
-Conversion("kcal/mol/bohr", "kcal/mol/ang", lambda x: get_conversion("ang", "bohr")(x)) -Conversion("ev/ang", "kcal/mol/ang", lambda x: get_conversion("ev", "kcal/mol")(x)) diff --git a/tests/test_base.py b/tests/test_base.py new file mode 100644 index 0000000..68f5037 --- /dev/null +++ b/tests/test_base.py @@ -0,0 +1,520 @@ +import numpy as np +import pytest +from numpy import array, float32 + +from openqdc.datasets.potential.dummy import PredefinedDataset +from openqdc.utils.units import get_conversion + + +@pytest.fixture +def original_units(): + return { + "energy_unit": "hartree", + "distance_unit": "bohr", + "forces_unit": "hartree/bohr", + } + + +@pytest.fixture +def original_energies(): + return array([[-90.0], [-230.0], [-10.0], [-200.0], [-100.0]]) + + +@pytest.fixture +def original_e0s_first_entry(): + return array( + [ + [-37.87264507], + [-37.87264507], + [-37.87264507], + [-37.87264507], + [-37.87264507], + [-37.87264507], + [-37.87264507], + [-37.87264507], + [-54.08594143], + [-0.49876051], + [-0.49876051], + [-0.49876051], + [-0.49876051], + [-0.49876051], + [-0.49876051], + [-0.49876051], + [-0.49876051], + [-0.49876051], + [-0.49876051], + [-0.49876051], + [-0.49876051], + [-0.49876051], + [-0.49876051], + [-0.49876051], + [-0.49876051], + [-0.49876051], + [-0.49876051], + ] + ) + + +@pytest.fixture +def original_distance(): + return array( + [ + [3.67694139e00, -2.25296164e00, -1.50531638e00], + [1.17158151e00, -3.31060600e00, -6.75342798e-01], + [3.31067538e00, 2.80183315e-01, -2.45827127e00], + [-6.73973337e-02, -1.54327595e00, 1.30509722e00], + [2.51685810e00, 2.17634439e00, -4.07443404e-01], + [-1.81137574e00, 5.37082815e00, 2.74436927e00], + [-6.63455278e-02, 1.36833620e00, 6.06603682e-01], + [-9.35529292e-01, 2.69734836e00, 3.05301452e00], + [1.13202202e00, 2.88020563e00, 4.87897444e00], + [4.55532551e00, -3.49109173e00, -2.93064141e00], + [4.88521004e00, -2.21581841e00, 2.15676889e-01], + [1.19775248e00, -5.14340162e00, 5.32178394e-02], 
+ [-1.64151995e-03, -3.47284317e00, -2.24242687e00], + [4.94823980e00, 9.86394286e-01, -3.76261663e00], + [1.55617833e00, 1.93074927e-01, -3.79791284e00], + [1.20725083e00, -1.61850047e00, 2.98583961e00], + [-2.03743386e00, -2.21026397e00, 1.38890767e00], + [3.92221451e00, 1.97753751e00, 9.90654707e-01], + [2.58457398e00, 4.29353046e00, -8.48216891e-01], + [-2.01352406e00, 6.39746761e00, 4.49058580e00], + [-3.73511076e00, 5.30047083e00, 2.06555033e00], + [-3.45413387e-01, 6.29020166e00, 1.62131798e00], + [-1.44347918e00, 1.69119561e00, -8.25058460e-01], + [-2.31007552e00, 1.52176499e00, 4.01171541e00], + [2.02086377e00, 1.13430703e00, 5.25842619e00], + [5.00191629e-01, 3.61749721e00, 6.61439800e00], + [2.58144808e00, 4.08207464e00, 4.16889763e00], + [4.40603590e00, 5.59871769e00, -1.35815811e00], + [-1.46873784e00, -3.88561106e00, -8.31218123e-01], + [1.12636681e01, 1.05836134e01, -9.79270840e00], + [-5.19884157e00, -1.06156540e00, 1.58325291e01], + [-2.70264530e00, -5.28654194e00, 4.04208899e00], + [4.95727301e00, 4.61567068e00, -3.33766818e00], + [3.83176136e00, 2.27276921e00, -3.72962642e00], + [1.53303146e00, 1.80276620e00, -2.66488409e00], + [9.06536484e00, 6.50896502e00, -9.07474422e00], + [8.58784103e00, 1.00848455e01, -6.14316130e00], + [-5.56771278e00, -3.25930190e00, 1.10719595e01], + [-2.44024372e00, -2.07635343e-01, 1.16389360e01], + [7.68177032e00, 5.01355362e00, -7.35591173e00], + [6.94557476e00, 8.64970875e00, -4.64556980e00], + [-4.81432915e00, -3.77891970e00, 8.72253704e00], + [-1.63055503e00, -7.25802898e-01, 9.20566368e00], + [9.72696686e00, 9.13173485e00, -8.38851833e00], + [-4.60693693e00, -1.41886199e00, 1.26682453e01], + [6.60912657e00, 6.07958031e00, -5.16560030e00], + [-2.92542553e00, -2.33112574e00, 7.53244352e00], + [-3.94577122e00, -1.18768215e00, 3.01625586e00], + [7.23172307e-01, -2.03086853e00, 4.16400003e00], + [-2.13760924e00, -2.77392507e00, 4.68810034e00], + [1.64781902e-02, -7.82176316e-01, -3.04831910e00], + [-2.80235934e00, 
-2.72733927e-01, 6.31411374e-01], + [1.36307573e00, -2.06423068e00, 1.27680647e00], + [-6.76723361e-01, -1.76777232e00, -5.23940384e-01], + [-2.63938713e00, -5.44667196e00, 2.21966410e00], + [4.94228649e00, 8.95530701e-01, -4.60446787e00], + [4.48645473e-01, 3.33815026e00, -1.70987070e00], + [9.82781410e00, 5.65525198e00, -1.07715559e01], + [9.08096695e00, 1.20195494e01, -5.54208851e00], + [-7.11611986e00, -4.28993273e00, 1.19259653e01], + [-1.55726564e00, 1.10154831e00, 1.27880898e01], + [7.14328480e00, 3.11889815e00, -7.86796093e00], + [6.45645618e00, 9.42563915e00, -2.88139629e00], + [-5.73853683e00, -5.37374210e00, 7.66268682e00], + [-2.45697454e-01, 4.17709291e-01, 8.39682102e00], + [-5.50794983e00, -2.51168823e00, 2.61494541e00], + [-4.60784054e00, 6.74198091e-01, 3.83336496e00], + [1.06903493e00, -1.14642583e-01, 4.83560658e00], + [2.25614643e00, -3.10614038e00, 5.05577421e00], + [-1.44668889e00, -3.94177079e-01, -4.41888762e00], + [1.36198246e00, -1.92607069e00, -3.95416617e00], + [-1.91955316e00, 1.44418514e00, 1.07350314e00], + [-4.27329683e00, -4.16186787e-02, -8.70908499e-01], + [2.40655351e00, -5.75643301e-01, 1.09908116e00], + [2.26591492e00, -3.81750226e00, 6.84852242e-01], + [3.67694139e00, -2.25296164e00, -1.50531638e00], + [1.17158151e00, -3.31060600e00, -6.75342798e-01], + [3.31067538e00, 2.80183315e-01, -2.45827127e00], + [-6.73973337e-02, -1.54327595e00, 1.30509722e00], + [2.51685810e00, 2.17634439e00, -4.07443404e-01], + [-1.81137574e00, 5.37082815e00, 2.74436927e00], + [-6.63455278e-02, 1.36833620e00, 6.06603682e-01], + [-9.35529292e-01, 2.69734836e00, 3.05301452e00], + [1.13202202e00, 2.88020563e00, 4.87897444e00], + [4.55532551e00, -3.49109173e00, -2.93064141e00], + [4.88521004e00, -2.21581841e00, 2.15676889e-01], + [1.19775248e00, -5.14340162e00, 5.32178394e-02], + [-1.64151995e-03, -3.47284317e00, -2.24242687e00], + [4.94823980e00, 9.86394286e-01, -3.76261663e00], + [1.55617833e00, 1.93074927e-01, -3.79791284e00], + [1.20725083e00, 
-1.61850047e00, 2.98583961e00], + [-2.03743386e00, -2.21026397e00, 1.38890767e00], + [3.92221451e00, 1.97753751e00, 9.90654707e-01], + [2.58457398e00, 4.29353046e00, -8.48216891e-01], + [-2.01352406e00, 6.39746761e00, 4.49058580e00], + [-3.73511076e00, 5.30047083e00, 2.06555033e00], + [-3.45413387e-01, 6.29020166e00, 1.62131798e00], + [-1.44347918e00, 1.69119561e00, -8.25058460e-01], + [-2.31007552e00, 1.52176499e00, 4.01171541e00], + [2.02086377e00, 1.13430703e00, 5.25842619e00], + [5.00191629e-01, 3.61749721e00, 6.61439800e00], + [2.58144808e00, 4.08207464e00, 4.16889763e00], + [-7.18970537e00, 6.47051430e00, -1.78789961e00], + [-7.66641283e00, 2.76283574e00, 5.69243860e00], + [1.46481133e00, -5.84496880e00, 2.75619698e00], + [-6.28430033e00, 5.75309563e00, 1.54569700e-01], + [-6.63203239e00, 3.83197808e00, 4.00907278e00], + [-1.00620240e-01, -4.05598497e00, 2.16277957e00], + [7.26440954e00, 3.27383471e00, -6.03921747e00], + [6.41222048e00, 8.87307286e-01, 7.54087210e00], + [5.44775534e00, 4.88935995e00, -5.64345312e00], + [7.18479490e00, 8.33229184e-01, -4.99916458e00], + [6.24100637e00, 1.94993556e00, 5.12957048e00], + [5.03426027e00, -1.32769716e00, 8.24502945e00], + [3.54736328e00, 4.31157160e00, -3.91554999e00], + [5.25473738e00, 1.54854096e-02, -3.52882981e00], + [-2.17661500e00, 3.27098155e00, -1.01027799e00], + [4.27832508e00, 1.09049726e00, 3.65444875e00], + [3.10736060e00, -2.15409160e00, 6.79668665e00], + [-2.84234262e00, 1.12899649e00, 3.69459701e00], + [-4.14098644e00, 4.06633806e00, 6.45772159e-01], + [-4.49631691e00, 2.81941366e00, 2.87427187e00], + [3.37165761e00, 1.76894200e00, -3.04554462e00], + [-3.90353531e-01, 1.53518534e00, -3.07558417e-01], + [2.65353012e00, -1.01803887e00, 4.32171679e00], + [-1.03777003e00, 3.19495380e-01, 2.05414677e00], + [-7.69457006e00, 5.50280476e00, 2.35955858e00], + [1.44933832e00, 5.78019023e-01, -1.78190553e00], + [3.61046970e-01, -1.62705970e00, 3.00323367e00], + [-2.58820343e00, -4.81352997e00, 6.28754616e-01], 
+ [8.82366371e00, 3.86827493e00, -7.45671415e00], + [7.98199987e00, 1.51466060e00, 8.68747044e00], + [5.69073820e00, 6.80793619e00, -6.29577160e00], + [8.71504879e00, -6.02105021e-01, -5.52924776e00], + [7.80299187e00, 3.20323944e00, 4.44802618e00], + [5.41853046e00, -2.57642674e00, 9.97229004e00], + [2.14713717e00, 5.74735880e00, -3.44167089e00], + [5.33957195e00, -1.92729735e00, -2.85384536e00], + [-2.27019310e00, 4.06121683e00, -2.77049541e00], + [4.16562462e00, 2.09750867e00, 1.83321524e00], + [1.80853200e00, -3.48297858e00, 7.66277742e00], + [-3.23054957e00, 5.41086674e-01, 5.65746450e00], + [-9.66177845e00, 5.57600975e00, 2.17735410e00], + [1.70237637e00, -1.25289285e00, -1.96954167e00], + [-4.03682232e00, -5.00325108e00, 1.99573731e00], + [-2.22336721e00, -6.53946781e00, -1.70130104e-01], + [-3.00589347e00, -3.25132108e00, -6.11233473e-01], + [-7.18970537e00, 6.47051430e00, -1.78789961e00], + [-7.66641283e00, 2.76283574e00, 5.69243860e00], + [1.46481133e00, -5.84496880e00, 2.75619698e00], + [-6.28430033e00, 5.75309563e00, 1.54569700e-01], + [-6.63203239e00, 3.83197808e00, 4.00907278e00], + [-1.00620240e-01, -4.05598497e00, 2.16277957e00], + [7.26440954e00, 3.27383471e00, -6.03921747e00], + [6.41222048e00, 8.87307286e-01, 7.54087210e00], + [5.44775534e00, 4.88935995e00, -5.64345312e00], + [7.18479490e00, 8.33229184e-01, -4.99916458e00], + [6.24100637e00, 1.94993556e00, 5.12957048e00], + [5.03426027e00, -1.32769716e00, 8.24502945e00], + [3.54736328e00, 4.31157160e00, -3.91554999e00], + [5.25473738e00, 1.54854096e-02, -3.52882981e00], + [-2.17661500e00, 3.27098155e00, -1.01027799e00], + [4.27832508e00, 1.09049726e00, 3.65444875e00], + [3.10736060e00, -2.15409160e00, 6.79668665e00], + [-2.84234262e00, 1.12899649e00, 3.69459701e00], + [-4.14098644e00, 4.06633806e00, 6.45772159e-01], + [-4.49631691e00, 2.81941366e00, 2.87427187e00], + [3.37165761e00, 1.76894200e00, -3.04554462e00], + [-3.90353531e-01, 1.53518534e00, -3.07558417e-01], + [2.65353012e00, 
-1.01803887e00, 4.32171679e00], + [-1.03777003e00, 3.19495380e-01, 2.05414677e00], + [-7.69457006e00, 5.50280476e00, 2.35955858e00], + [1.44933832e00, 5.78019023e-01, -1.78190553e00], + [3.61046970e-01, -1.62705970e00, 3.00323367e00], + [-2.58820343e00, -4.81352997e00, 6.28754616e-01], + [8.82366371e00, 3.86827493e00, -7.45671415e00], + [7.98199987e00, 1.51466060e00, 8.68747044e00], + [5.69073820e00, 6.80793619e00, -6.29577160e00], + [8.71504879e00, -6.02105021e-01, -5.52924776e00], + [7.80299187e00, 3.20323944e00, 4.44802618e00], + [5.41853046e00, -2.57642674e00, 9.97229004e00], + [2.14713717e00, 5.74735880e00, -3.44167089e00], + [5.33957195e00, -1.92729735e00, -2.85384536e00], + [-2.27019310e00, 4.06121683e00, -2.77049541e00], + [4.16562462e00, 2.09750867e00, 1.83321524e00], + [1.80853200e00, -3.48297858e00, 7.66277742e00], + [-3.23054957e00, 5.41086674e-01, 5.65746450e00], + [-9.66177845e00, 5.57600975e00, 2.17735410e00], + [1.70237637e00, -1.25289285e00, -1.96954167e00], + [-4.03682232e00, -5.00325108e00, 1.99573731e00], + [-2.22336721e00, -6.53946781e00, -1.70130104e-01], + [-3.00589347e00, -3.25132108e00, -6.11233473e-01], + ], + dtype=float32, + ) + + +@pytest.fixture +def original_forces(): + return array( + [ + [ + [2.0098940e-02, -4.4371959e-02, 2.5288060e-02], + [3.0883900e-03, 3.4925319e-02, 1.9593900e-02], + [-1.8585989e-02, 6.3869603e-02, -7.7826567e-02], + [1.4071250e-02, 9.9647399e-03, -2.5321390e-02], + [-7.1518999e-03, 2.2885300e-02, -1.0252580e-02], + [-3.9557400e-03, -2.3914000e-03, -3.4818000e-03], + [8.9494903e-03, -2.7339499e-02, 2.3479421e-02], + [-1.1799960e-02, -5.7527999e-04, -7.2162002e-03], + [5.6334689e-02, -1.8624000e-03, 3.2000460e-02], + [-1.4379700e-03, 5.6493902e-03, 7.8962799e-03], + [2.2008100e-03, -2.1426000e-03, -1.2427660e-02], + [9.3885399e-03, -2.9421590e-02, 1.6655169e-02], + [-2.4729479e-02, 2.3152300e-03, -3.1581681e-02], + [-2.8887330e-02, -1.0241980e-02, 3.0818630e-02], + [3.2719430e-02, 2.6927399e-03, 1.5098770e-02], 
+ [-2.1272870e-02, -1.2256420e-02, -1.8856700e-03], + [9.4753504e-03, 5.6752702e-03, 2.1391720e-02], + [1.9571841e-02, -3.5972199e-03, 2.5518680e-02], + [-3.6662901e-03, -3.2077029e-02, -6.9814702e-03], + [4.3959701e-03, 1.2205200e-03, 1.1675710e-02], + [1.8410700e-03, 8.3344998e-03, -1.8301290e-02], + [-1.4716200e-02, 8.6473804e-03, -7.3494101e-03], + [-1.2910850e-02, 5.2068601e-03, -9.4032800e-03], + [-6.5030898e-03, -1.0604300e-03, -2.2946999e-03], + [-1.2682850e-02, 2.1387359e-02, -5.3574000e-03], + [8.9004301e-03, -8.3064800e-03, -2.2739170e-02], + [-2.2720771e-02, -1.7126599e-02, 1.2999990e-02], + [3.7794700e-03, 3.0895749e-02, 2.8142899e-02], + [-8.8798188e-02, -2.0182565e-01, -4.2870991e-02], + [-8.7965000e-03, -1.0801400e-02, 5.5787200e-03], + [-1.1793530e-02, -3.0176001e-04, 1.3800600e-03], + [8.7435301e-03, -1.3199900e-02, 1.7136870e-02], + [2.1160210e-02, 5.8726329e-02, 4.8154001e-03], + [-5.7599049e-02, -4.8749380e-02, 1.0423710e-02], + [-9.4849197e-03, -1.0181380e-02, -2.8085450e-02], + [1.9647470e-02, 3.5608239e-02, 2.2441249e-02], + [1.5643651e-02, 1.0229980e-02, -2.1429140e-02], + [-4.8281200e-02, -1.9374500e-03, 7.0114747e-02], + [-3.9701089e-02, -2.5513770e-02, 3.3872999e-04], + [-9.2138899e-03, 3.0794051e-02, -7.7717099e-03], + [4.1673161e-02, -1.5535990e-02, -2.2655200e-02], + [5.7828198e-03, -1.7883500e-02, -7.2353020e-02], + [-3.1263210e-02, 5.9046003e-04, -1.8511759e-02], + [-4.4043671e-02, -5.1590171e-02, 2.3608640e-02], + [6.2813938e-02, 3.3906000e-03, -4.2455290e-02], + [-5.6233299e-03, -4.6967301e-03, -6.1163399e-03], + [1.4563670e-02, -2.2854021e-02, 1.2023320e-02], + [-9.2339199e-03, 3.2203200e-03, 1.0124270e-02], + [-2.2176741e-02, -1.2340700e-03, -1.6926041e-02], + [3.0724001e-03, 9.9830804e-03, 2.2748690e-02], + [5.8463689e-02, 9.0885743e-02, 1.1359200e-02], + [1.3506500e-03, 2.6117770e-02, -3.7040271e-02], + [-1.3757030e-02, -4.9668878e-02, 3.7864018e-02], + [-1.0984430e-02, 9.2085890e-02, 2.0700229e-02], + [6.2799300e-03, 
6.1981301e-03, -4.6103401e-03], + [6.2065101e-03, -2.6552090e-02, -1.9207031e-02], + [1.7216250e-02, -1.2792430e-02, -9.5830997e-03], + [-6.9876999e-04, 3.6903000e-03, 7.8682002e-04], + [-4.5436099e-03, -1.2266140e-02, -5.8457400e-03], + [-6.5407000e-04, -8.1518097e-03, -5.7564098e-03], + [2.3151031e-02, 2.5158480e-02, 1.7081439e-02], + [5.8197700e-03, -3.8674099e-03, -1.4005800e-03], + [-1.6544331e-02, 5.6582098e-03, 1.5035230e-02], + [1.1229410e-02, 2.2075661e-02, 1.2264130e-02], + [2.7446371e-02, 1.5480930e-02, -3.8228000e-03], + [3.0165601e-03, 1.4834640e-02, 6.4639798e-03], + [1.1818300e-03, -2.9177111e-02, 2.1789600e-03], + [-6.3775801e-03, 1.8332400e-03, -2.5221901e-03], + [-1.6519601e-02, -7.3949001e-03, 2.9081400e-03], + [-1.6859289e-02, -7.3338100e-03, 1.3897040e-02], + [1.1155440e-02, -2.9904990e-02, -1.8528100e-03], + [5.4696398e-03, 3.5557911e-02, 7.4013998e-04], + [1.7369010e-02, -9.2426001e-04, 1.2370120e-02], + [7.9970531e-02, 8.4416710e-02, -1.4149260e-02], + [6.4725999e-04, 6.9413399e-03, 2.3965500e-03], + [2.0098940e-02, -4.4371959e-02, 2.5288060e-02], + [3.0883900e-03, 3.4925319e-02, 1.9593900e-02], + [-1.8585989e-02, 6.3869603e-02, -7.7826567e-02], + [1.4071250e-02, 9.9647399e-03, -2.5321390e-02], + [-7.1518999e-03, 2.2885300e-02, -1.0252580e-02], + [-3.9557400e-03, -2.3914000e-03, -3.4818000e-03], + [8.9494903e-03, -2.7339499e-02, 2.3479421e-02], + [-1.1799960e-02, -5.7527999e-04, -7.2162002e-03], + [5.6334689e-02, -1.8624000e-03, 3.2000460e-02], + [-1.4379700e-03, 5.6493902e-03, 7.8962799e-03], + [2.2008100e-03, -2.1426000e-03, -1.2427660e-02], + [9.3885399e-03, -2.9421590e-02, 1.6655169e-02], + [-2.4729479e-02, 2.3152300e-03, -3.1581681e-02], + [-2.8887330e-02, -1.0241980e-02, 3.0818630e-02], + [3.2719430e-02, 2.6927399e-03, 1.5098770e-02], + [-2.1272870e-02, -1.2256420e-02, -1.8856700e-03], + [9.4753504e-03, 5.6752702e-03, 2.1391720e-02], + [1.9571841e-02, -3.5972199e-03, 2.5518680e-02], + [-3.6662901e-03, -3.2077029e-02, -6.9814702e-03], + 
[4.3959701e-03, 1.2205200e-03, 1.1675710e-02], + [1.8410700e-03, 8.3344998e-03, -1.8301290e-02], + [-1.4716200e-02, 8.6473804e-03, -7.3494101e-03], + [-1.2910850e-02, 5.2068601e-03, -9.4032800e-03], + [-6.5030898e-03, -1.0604300e-03, -2.2946999e-03], + [-1.2682850e-02, 2.1387359e-02, -5.3574000e-03], + [8.9004301e-03, -8.3064800e-03, -2.2739170e-02], + [-2.2720771e-02, -1.7126599e-02, 1.2999990e-02], + [9.7695002e-03, 2.0127570e-02, -1.2980280e-02], + [-2.7942570e-02, 9.2920999e-04, 4.6880729e-02], + [-6.1051190e-02, 6.8795778e-02, -2.5638610e-02], + [-2.8667970e-02, -3.5424151e-02, -1.9555001e-03], + [-4.7071699e-02, -2.3389090e-02, 6.5093199e-03], + [5.0008319e-02, -8.8082343e-02, 2.3161080e-02], + [8.6634941e-02, -4.7883440e-02, -5.4378539e-02], + [2.3322999e-02, -2.7841071e-02, 1.0350000e-05], + [-4.0369540e-02, 5.6349069e-02, 4.0032700e-02], + [5.0734740e-02, -2.4287131e-02, -3.1415019e-02], + [7.2482699e-03, -1.4482470e-02, 3.1026891e-02], + [2.9213680e-02, 1.4185420e-02, 2.9727720e-02], + [-2.2642011e-02, -2.4963060e-02, 6.6630901e-03], + [2.5026379e-02, 8.9355102e-03, -1.6467299e-02], + [-2.8570071e-02, 2.9602069e-02, 3.3173598e-02], + [-2.7714010e-02, -2.3182770e-02, -3.6700711e-02], + [-3.2320909e-02, 1.1975440e-02, -5.3553499e-02], + [-8.2072103e-03, -2.1984169e-02, 3.1608630e-02], + [5.7557411e-02, -2.4288220e-02, -4.2348061e-02], + [3.9058849e-02, 5.5638582e-02, -3.9625522e-02], + [-8.7162002e-04, 2.4371840e-02, -8.3007198e-03], + [-5.6681771e-02, -2.3770170e-02, 4.5389850e-02], + [2.8463850e-02, 2.7666900e-02, 5.1420581e-02], + [1.3180960e-02, 3.0710639e-02, -6.3042301e-03], + [-4.9958970e-02, 3.9361179e-02, 1.8747999e-04], + [-4.7463900e-03, 1.2972400e-02, -6.2818299e-03], + [4.1373100e-02, -5.0648849e-02, -1.3128620e-02], + [4.3331161e-02, 6.5269530e-02, 4.0079061e-02], + [-2.4878871e-02, -1.7338211e-02, 3.0210350e-02], + [-1.7828700e-03, -1.8141600e-03, 5.5730799e-03], + [-8.0383504e-03, 7.1810000e-04, -4.2441301e-03], + [-2.1777401e-02, 
2.6862159e-02, 9.7052502e-03], + [-2.7382219e-02, -6.1902702e-03, 6.8063498e-03], + [-1.6110300e-03, 3.4320548e-02, -2.1525670e-02], + [1.5957700e-03, -9.7004296e-03, -1.9352000e-04], + [-9.3391901e-03, 7.4654100e-03, 2.8138000e-03], + [8.6294999e-03, 1.9359371e-02, -4.5523290e-02], + [-9.0313796e-03, -1.4879520e-02, 8.4688598e-03], + [2.7567500e-03, -6.1969701e-03, -1.4553310e-02], + [1.5009880e-02, -1.3113780e-02, -1.5555550e-02], + [3.0122399e-02, 5.7058702e-03, 8.4186699e-03], + [-6.7246798e-03, -2.2097429e-02, 1.5180030e-02], + [-1.7451471e-02, 2.4796999e-04, 1.2653390e-02], + [9.5762303e-03, -3.6768749e-02, -2.4743490e-02], + [-7.7048899e-03, -3.2854499e-03, -1.0304260e-02], + [9.7695002e-03, 2.0127570e-02, -1.2980280e-02], + [-2.7942570e-02, 9.2920999e-04, 4.6880729e-02], + [-6.1051190e-02, 6.8795778e-02, -2.5638610e-02], + [-2.8667970e-02, -3.5424151e-02, -1.9555001e-03], + [-4.7071699e-02, -2.3389090e-02, 6.5093199e-03], + [5.0008319e-02, -8.8082343e-02, 2.3161080e-02], + [8.6634941e-02, -4.7883440e-02, -5.4378539e-02], + [2.3322999e-02, -2.7841071e-02, 1.0350000e-05], + [-4.0369540e-02, 5.6349069e-02, 4.0032700e-02], + [5.0734740e-02, -2.4287131e-02, -3.1415019e-02], + [7.2482699e-03, -1.4482470e-02, 3.1026891e-02], + [2.9213680e-02, 1.4185420e-02, 2.9727720e-02], + [-2.2642011e-02, -2.4963060e-02, 6.6630901e-03], + [2.5026379e-02, 8.9355102e-03, -1.6467299e-02], + [-2.8570071e-02, 2.9602069e-02, 3.3173598e-02], + [-2.7714010e-02, -2.3182770e-02, -3.6700711e-02], + [-3.2320909e-02, 1.1975440e-02, -5.3553499e-02], + [-8.2072103e-03, -2.1984169e-02, 3.1608630e-02], + [5.7557411e-02, -2.4288220e-02, -4.2348061e-02], + [3.9058849e-02, 5.5638582e-02, -3.9625522e-02], + [-8.7162002e-04, 2.4371840e-02, -8.3007198e-03], + [-5.6681771e-02, -2.3770170e-02, 4.5389850e-02], + [2.8463850e-02, 2.7666900e-02, 5.1420581e-02], + [1.3180960e-02, 3.0710639e-02, -6.3042301e-03], + [-4.9958970e-02, 3.9361179e-02, 1.8747999e-04], + [-4.7463900e-03, 1.2972400e-02, 
-6.2818299e-03], + [4.1373100e-02, -5.0648849e-02, -1.3128620e-02], + [4.3331161e-02, 6.5269530e-02, 4.0079061e-02], + [-2.4878871e-02, -1.7338211e-02, 3.0210350e-02], + [-1.7828700e-03, -1.8141600e-03, 5.5730799e-03], + [-8.0383504e-03, 7.1810000e-04, -4.2441301e-03], + [-2.1777401e-02, 2.6862159e-02, 9.7052502e-03], + [-2.7382219e-02, -6.1902702e-03, 6.8063498e-03], + [-1.6110300e-03, 3.4320548e-02, -2.1525670e-02], + [1.5957700e-03, -9.7004296e-03, -1.9352000e-04], + [-9.3391901e-03, 7.4654100e-03, 2.8138000e-03], + [8.6294999e-03, 1.9359371e-02, -4.5523290e-02], + [-9.0313796e-03, -1.4879520e-02, 8.4688598e-03], + [2.7567500e-03, -6.1969701e-03, -1.4553310e-02], + [1.5009880e-02, -1.3113780e-02, -1.5555550e-02], + [3.0122399e-02, 5.7058702e-03, 8.4186699e-03], + [-6.7246798e-03, -2.2097429e-02, 1.5180030e-02], + [-1.7451471e-02, 2.4796999e-04, 1.2653390e-02], + [9.5762303e-03, -3.6768749e-02, -2.4743490e-02], + [-7.7048899e-03, -3.2854499e-03, -1.0304260e-02], + ] + ], + dtype=float32, + ) + + +def test_no_conversion(original_energies, original_distance, original_forces, original_e0s_first_entry): + ds = PredefinedDataset() + np.testing.assert_almost_equal(ds.data["energies"], original_energies) + np.testing.assert_almost_equal(ds.data["atomic_inputs"][:, 2:], original_distance) + np.testing.assert_almost_equal(ds.data["forces"].reshape(1, 192, 3), original_forces) + np.testing.assert_almost_equal(ds[0].e0, original_e0s_first_entry) + + +def test_same_conversion( + original_energies, original_distance, original_forces, original_e0s_first_entry, original_units +): + ds = PredefinedDataset(distance_unit=original_units["distance_unit"], energy_unit=original_units["energy_unit"]) + np.testing.assert_almost_equal(ds.data["energies"], original_energies) + np.testing.assert_almost_equal(ds.data["atomic_inputs"][:, 2:], original_distance) + np.testing.assert_almost_equal(ds.data["forces"].reshape(1, 192, 3), original_forces) + np.testing.assert_almost_equal(ds[0].e0, 
original_e0s_first_entry) + + +def test_energy_conversion( + original_energies, original_distance, original_forces, original_e0s_first_entry, original_units +): + ds = PredefinedDataset(energy_unit="ev") + + en_conversion_fn = get_conversion(original_units["energy_unit"], "ev") + ds_conversion_fn = get_conversion(original_units["distance_unit"], original_units["distance_unit"]) + frcs_conversion_fn = lambda x: (get_conversion(original_units["energy_unit"], "ev")(x)) + + np.testing.assert_almost_equal(ds.data["energies"], en_conversion_fn(original_energies)) + np.testing.assert_almost_equal(ds.data["atomic_inputs"][:, 2:], ds_conversion_fn(original_distance)) + np.testing.assert_almost_equal(ds.data["forces"].reshape(1, 192, 3), frcs_conversion_fn(original_forces)) + np.testing.assert_almost_equal(ds[0].e0, en_conversion_fn(original_e0s_first_entry)) + + +def test_distance_conversion( + original_energies, original_distance, original_forces, original_e0s_first_entry, original_units +): + ds = PredefinedDataset(distance_unit="ang") + en_conversion_fn = get_conversion(original_units["energy_unit"], original_units["energy_unit"]) + ds_conversion_fn = get_conversion(original_units["distance_unit"], "ang") + frcs_conversion_fn = lambda x: (get_conversion("ang", original_units["distance_unit"])(x)) + + np.testing.assert_almost_equal(ds.data["energies"], en_conversion_fn(original_energies)) + np.testing.assert_almost_equal(ds.data["atomic_inputs"][:, 2:], ds_conversion_fn(original_distance)) + np.testing.assert_almost_equal(ds.data["forces"].reshape(1, 192, 3), frcs_conversion_fn(original_forces)) + np.testing.assert_almost_equal(ds[0].e0, en_conversion_fn(original_e0s_first_entry)) + + +def test_force_conversion( + original_energies, original_distance, original_forces, original_e0s_first_entry, original_units +): + ds = PredefinedDataset(energy_unit="kcal/mol", distance_unit="ang") + en_conversion_fn = get_conversion(original_units["energy_unit"], "kcal/mol") + 
ds_conversion_fn = get_conversion(original_units["distance_unit"], "ang") + frcs_conversion_fn = lambda x: (get_conversion("ang", original_units["distance_unit"])(en_conversion_fn(x))) + + np.testing.assert_almost_equal(ds.data["energies"], en_conversion_fn(original_energies)) + np.testing.assert_almost_equal(ds.data["atomic_inputs"][:, 2:], ds_conversion_fn(original_distance)) + np.testing.assert_almost_equal(ds.data["forces"].reshape(1, 192, 3), frcs_conversion_fn(original_forces)) + np.testing.assert_almost_equal(ds[0].e0, en_conversion_fn(original_e0s_first_entry), decimal=4) diff --git a/tests/test_filedataset.py b/tests/test_filedataset.py index 8defc7f..d0d500d 100644 --- a/tests/test_filedataset.py +++ b/tests/test_filedataset.py @@ -1,3 +1,4 @@ +import os from io import StringIO import numpy as np @@ -5,6 +6,7 @@ from openqdc.datasets.io import XYZDataset from openqdc.methods.enums import PotentialMethod +from openqdc.utils.io import get_local_cache from openqdc.utils.package_utils import has_package if has_package("torch"): @@ -20,6 +22,14 @@ } +@pytest.fixture(autouse=True) +def clean_before_run(): + # start by removing any cached data + cache_dir = get_local_cache() + os.system(f"rm -rf {cache_dir}/XYZDataset") + yield + + @pytest.fixture def xyz_filelike(): xyz_str = """3