From e999f14739e903c0c5c6bac4fa265bb5f2523c35 Mon Sep 17 00:00:00 2001 From: Marvin Friede <51965259+marvinfriede@users.noreply.github.com> Date: Mon, 26 Aug 2024 08:50:15 -0500 Subject: [PATCH] Refactor integral code (#162) --- .github/workflows/ubuntu-pytorch-1.yaml | 11 +- .github/workflows/ubuntu.yaml | 9 +- docs/source/03_for_developers/errors.rst | 15 + examples/limitation_xitorch.py | 13 +- examples/profiling/batch-vs-seq-nicotine.py | 9 +- examples/profiling/importing.py | 2 + examples/run-all.sh | 19 + setup.cfg | 8 +- src/dxtb/_src/calculators/config/cache.py | 20 + .../properties/vibration/analysis.py | 7 +- src/dxtb/_src/calculators/types/__init__.py | 15 +- src/dxtb/_src/calculators/types/analytical.py | 27 +- src/dxtb/_src/calculators/types/autograd.py | 71 +-- src/dxtb/_src/calculators/types/base.py | 40 +- src/dxtb/_src/calculators/types/decorators.py | 1 - src/dxtb/_src/calculators/types/energy.py | 60 ++- src/dxtb/_src/calculators/types/numerical.py | 15 +- src/dxtb/_src/cli/driver.py | 29 +- .../classicals/dispersion/__init__.py | 1 - .../components/classicals/dispersion/d3.py | 3 +- .../components/classicals/dispersion/d4.py | 3 +- .../components/classicals/halogen/__init__.py | 6 +- .../interactions/coulomb/secondorder.py | 8 +- .../interactions/coulomb/thirdorder.py | 8 +- .../components/interactions/solvation/alpb.py | 8 +- .../components/interactions/solvation/born.py | 8 +- src/dxtb/_src/constants/labels/method.py | 10 + src/dxtb/_src/constants/labels/scf.py | 36 ++ .../exlibs/xitorch/_core/editable_module.py | 13 +- .../_src/exlibs/xitorch/_utils/decorators.py | 56 --- src/dxtb/_src/integral/abc.py | 85 ++++ src/dxtb/_src/integral/base.py | 252 ++++++---- src/dxtb/_src/integral/container.py | 459 +++++++++--------- src/dxtb/_src/integral/driver/__init__.py | 3 +- src/dxtb/_src/integral/driver/factory.py | 125 +++++ src/dxtb/_src/integral/driver/libcint/base.py | 45 ++ .../integral/driver/libcint/base_driver.py | 128 ----- .../driver/libcint/base_implementation.py | 161 ------ .../_src/integral/driver/libcint/dipole.py | 49 +- .../_src/integral/driver/libcint/driver.py | 74 ++- .../_src/integral/driver/libcint/multipole.py | 42 +- .../_src/integral/driver/libcint/overlap.py | 95 ++-- .../integral/driver/libcint/quadrupole.py | 246 +--------- src/dxtb/_src/integral/driver/manager.py | 132 +++++ .../_src/integral/driver/pytorch/__init__.py | 4 + src/dxtb/_src/integral/driver/pytorch/base.py | 45 ++ .../integral/driver/pytorch/base_driver.py | 146 ------ .../driver/pytorch/base_implementation.py | 69 --- .../_src/integral/driver/pytorch/dipole.py | 109 +++++ .../_src/integral/driver/pytorch/driver.py | 111 ++++- .../driver/pytorch/impls/md/explicit.py | 10 +- .../driver/pytorch/impls/md/recursion.py | 2 +- .../integral/driver/pytorch/impls/md/trafo.py | 4 +- .../_src/integral/driver/pytorch/overlap.py | 29 +- .../integral/driver/pytorch/quadrupole.py | 109 +++++ src/dxtb/_src/integral/factory.py | 214 +++++--- src/dxtb/_src/integral/types/__init__.py | 4 +- src/dxtb/_src/integral/types/base.py | 243 ---------- src/dxtb/_src/integral/types/dipole.py | 80 +-- src/dxtb/_src/integral/types/h0.py | 67 --- src/dxtb/_src/integral/types/overlap.py | 38 +- src/dxtb/_src/integral/types/quadrupole.py | 277 +++++++++-- src/dxtb/_src/integral/utils.py | 34 ++ src/dxtb/_src/integral/wrappers.py | 95 ++-- src/dxtb/_src/typing/project.py | 3 +- src/dxtb/_src/utils/misc.py | 2 +- src/dxtb/_src/xtb/abc.py | 94 ++++ src/dxtb/_src/xtb/base.py | 104 ++-- src/dxtb/_src/xtb/gfn1.py | 29 +- 
src/dxtb/_src/xtb/gfn2.py | 4 +- src/dxtb/calculators.py | 9 + src/dxtb/components/base.py | 1 + src/dxtb/components/coulomb.py | 7 +- src/dxtb/components/dispersion.py | 6 +- src/dxtb/components/halogen.py | 5 +- src/dxtb/components/repulsion.py | 5 +- src/dxtb/components/solvation.py | 5 +- src/dxtb/config.py | 9 + src/dxtb/integrals/__init__.py | 4 +- src/dxtb/integrals/factories.py | 29 ++ src/dxtb/integrals/types.py | 9 +- src/dxtb/typing.py | 8 + test/conftest.py | 3 + test/test_a_memory_leak/test_higher_deriv.py | 10 +- test/test_a_memory_leak/test_repulsion.py | 12 +- test/test_a_memory_leak/test_scf.py | 24 +- .../test_cache/test_integrals.py | 88 ++++ .../test_cache/test_properties.py | 38 +- test/test_calculator/test_general.py | 13 +- test/test_coulomb/test_es2_atom.py | 12 +- test/test_coulomb/test_es2_general.py | 4 +- test/test_coulomb/test_es2_shell.py | 12 +- test/test_coulomb/test_es3.py | 4 +- test/test_coulomb/test_grad_atom.py | 24 +- test/test_coulomb/test_grad_atom_param.py | 10 +- test/test_coulomb/test_grad_atom_pos.py | 26 +- test/test_coulomb/test_grad_shell.py | 24 +- test/test_coulomb/test_grad_shell_param.py | 14 +- test/test_coulomb/test_grad_shell_pos.py | 32 +- test/test_dispersion/test_d3.py | 18 +- test/test_dispersion/test_d4.py | 8 +- test/test_dispersion/test_grad_pos.py | 35 +- test/test_dispersion/test_hess.py | 22 +- test/test_external/test_field.py | 2 + test/test_halogen/test_grad_pos.py | 56 +-- test/test_halogen/test_hess.py | 12 +- test/test_hamiltonian/skip_test_grad.py | 18 +- test/test_hamiltonian/test_base.py | 69 +++ test/test_hamiltonian/test_grad_pos.py | 30 +- test/test_hamiltonian/test_h0.py | 7 +- test/test_indexhelper/test_extra.py | 3 +- test/test_integrals/test_driver/__init__.py | 16 + .../test_driver/test_factory.py | 70 +++ .../test_manager.py} | 95 ++-- .../test_driver/test_pytorch.py | 168 +++++++ test/test_integrals/test_factory.py | 191 ++++++++ test/test_integrals/test_general.py | 37 +- test/test_integrals/test_libcint.py | 111 ++--- test/test_integrals/test_pytorch.py | 13 +- test/test_integrals/test_types.py | 36 +- test/test_integrals/test_wrappers.py | 6 +- test/test_interaction/test_grad.py | 20 +- test/test_libcint/test_gradcheck.py | 12 +- test/test_libcint/test_overlap_grad.py | 8 +- test/test_multipole/todo_test_dipole_grad.py | 6 +- test/test_overlap/test_grad_pos.py | 34 +- test/test_overlap/test_gradient_grad_pos.py | 8 +- test/test_param/test_param.py | 2 +- test/test_properties/test_forces.py | 8 +- test/test_properties/test_hessian.py | 8 +- test/test_properties/test_vibration.py | 8 +- test/test_properties/todo_test_quadrupole.py | 18 +- test/test_repulsion/test_grad_pos.py | 52 +- test/test_repulsion/test_hess.py | 15 +- test/test_scf/skip_test_grad_pos.py | 8 +- test/test_scf/test_charged.py | 6 +- test/test_scf/test_grad.py | 46 +- test/test_scf/test_guess_grad.py | 16 +- test/test_scf/test_hess.py | 18 +- .../test_grad_pos_withfield.py | 16 +- test/test_singlepoint/test_hess.py | 10 +- test/test_solvation/test_born.py | 16 +- test/test_solvation/test_grad.py | 12 +- tox.ini | 11 +- 144 files changed, 3551 insertions(+), 2630 deletions(-) create mode 100755 examples/run-all.sh delete mode 100644 src/dxtb/_src/exlibs/xitorch/_utils/decorators.py create mode 100644 src/dxtb/_src/integral/abc.py create mode 100644 src/dxtb/_src/integral/driver/factory.py create mode 100644 src/dxtb/_src/integral/driver/libcint/base.py delete mode 100644 src/dxtb/_src/integral/driver/libcint/base_driver.py delete mode 100644 
src/dxtb/_src/integral/driver/libcint/base_implementation.py create mode 100644 src/dxtb/_src/integral/driver/manager.py create mode 100644 src/dxtb/_src/integral/driver/pytorch/base.py delete mode 100644 src/dxtb/_src/integral/driver/pytorch/base_driver.py delete mode 100644 src/dxtb/_src/integral/driver/pytorch/base_implementation.py create mode 100644 src/dxtb/_src/integral/driver/pytorch/dipole.py create mode 100644 src/dxtb/_src/integral/driver/pytorch/quadrupole.py delete mode 100644 src/dxtb/_src/integral/types/base.py delete mode 100644 src/dxtb/_src/integral/types/h0.py create mode 100644 src/dxtb/_src/integral/utils.py create mode 100644 src/dxtb/_src/xtb/abc.py create mode 100644 src/dxtb/integrals/factories.py create mode 100644 test/test_calculator/test_cache/test_integrals.py create mode 100644 test/test_hamiltonian/test_base.py create mode 100644 test/test_integrals/test_driver/__init__.py create mode 100644 test/test_integrals/test_driver/test_factory.py rename test/test_integrals/{test_driver.py => test_driver/test_manager.py} (54%) create mode 100644 test/test_integrals/test_driver/test_pytorch.py create mode 100644 test/test_integrals/test_factory.py diff --git a/.github/workflows/ubuntu-pytorch-1.yaml b/.github/workflows/ubuntu-pytorch-1.yaml index 28d68ff4d..e35c52e29 100644 --- a/.github/workflows/ubuntu-pytorch-1.yaml +++ b/.github/workflows/ubuntu-pytorch-1.yaml @@ -47,13 +47,12 @@ jobs: exclude: # Check latest versions here: https://download.pytorch.org/whl/torch/ # - # PyTorch now fully supports Python=<3.11 - # see: https://github.com/pytorch/pytorch/issues/86566 + # PyTorch issues: + # 3.11: https://github.com/pytorch/pytorch/issues/86566 + # 3.12: https://github.com/pytorch/pytorch/issues/110436 + # 3.13: https://github.com/pytorch/pytorch/issues/130249 # - # PyTorch does now support Python 3.12 (Linux) for 2.2.0 and newer - # see: https://github.com/pytorch/pytorch/issues/110436 - # - # PyTorch<1.13.0 does only support Python=<3.10 + # PyTorch<1.13.0 only supports Python<3.11 (Linux) - python-version: "3.11" torch-version: "1.11.0" - python-version: "3.11" diff --git a/.github/workflows/ubuntu.yaml b/.github/workflows/ubuntu.yaml index 9a8be185d..262879478 100644 --- a/.github/workflows/ubuntu.yaml +++ b/.github/workflows/ubuntu.yaml @@ -48,11 +48,12 @@ jobs: exclude: # Check latest versions here: https://download.pytorch.org/whl/torch/ # - # PyTorch fully supports Python=<3.11 - # see: https://github.com/pytorch/pytorch/issues/86566 + # PyTorch issues: + # 3.11: https://github.com/pytorch/pytorch/issues/86566 + # 3.12: https://github.com/pytorch/pytorch/issues/110436 + # 3.13: https://github.com/pytorch/pytorch/issues/130249 # - # PyTorch supports Python 3.12 (Linux) for 2.2.0 and newer - # see: https://github.com/pytorch/pytorch/issues/110436 + # PyTorch<2.2.0 only supports Python<3.12 (all platforms) - python-version: "3.12" torch-version: "2.0.1" - python-version: "3.12" diff --git a/docs/source/03_for_developers/errors.rst b/docs/source/03_for_developers/errors.rst index ce912b6f6..4f186fce7 100644 --- a/docs/source/03_for_developers/errors.rst +++ b/docs/source/03_for_developers/errors.rst @@ -50,3 +50,18 @@ the error is actually raised because the mixer in `xitorch` produces a Apparently, this occurs if the convergence criteria are too strict in single precision. The solution is to increase the convergence criteria to more than 1e-6.
+ + +RuntimeError: clone is not supported by NestedIntSymNode +-------------------------------------------------------- + +This is a bug in PyTorch 2.3.0 and 2.3.1 (see +`PyTorch #128607 <https://github.com/pytorch/pytorch/issues/128607>`__). +To avoid this error, manually import `torch._dynamo` in the code. For example: + +.. code-block:: python + + from tad_mctc._version import __tversion__ + + if __tversion__ in ((2, 3, 0), (2, 3, 1)): + import torch._dynamo diff --git a/examples/limitation_xitorch.py b/examples/limitation_xitorch.py index 284261e49..eba25a0a1 100644 --- a/examples/limitation_xitorch.py +++ b/examples/limitation_xitorch.py @@ -20,21 +20,30 @@ from pathlib import Path import torch +from tad_mctc._version import __tversion__ from tad_mctc.io import read import dxtb from dxtb.typing import DD +if __tversion__ in ((2, 3, 0), (2, 3, 1)): + import torch._dynamo + + dd: DD = {"device": torch.device("cpu"), "dtype": torch.double} f = Path(__file__).parent / "molecules" / "lih.xyz" numbers, positions = read.read_from_path(f, **dd) charge = read.read_chrg_from_path(f, **dd) -opts = {"verbosity": 3, "scf_mode": "nonpure"} +opts = {"verbosity": 0, "scf_mode": "nonpure"} ###################################################################### calc = dxtb.Calculator(numbers, dxtb.GFN1_XTB, opts=opts, **dd) pos = positions.clone().requires_grad_(True) -hess = calc.hessian(pos, chrg=charge, use_functorch=True) + +try: + calc.hessian(pos, chrg=charge, use_functorch=True) +except RuntimeError as e: + print(f"RuntimeError:\n{str(e)}") diff --git a/examples/profiling/batch-vs-seq-nicotine.py b/examples/profiling/batch-vs-seq-nicotine.py index 9ef06fe04..90ffde7a1 100644 --- a/examples/profiling/batch-vs-seq-nicotine.py +++ b/examples/profiling/batch-vs-seq-nicotine.py @@ -29,7 +29,14 @@ dxtb.timer.cuda_sync = False # read molecule from file -f = Path(__file__).parent / "molecules" / "nicotine.xyz" +p = Path(__file__).parent + +if "molecules" not in [x.name for x in p.iterdir()]: + p = p.parent + +f = p / "molecules" / "nicotine.xyz" + + numbers, positions = read.read_from_path(f, **dd) charge = read.read_chrg_from_path(f, **dd) diff --git a/examples/profiling/importing.py b/examples/profiling/importing.py index 52b35a2b0..dc008a44b 100644 --- a/examples/profiling/importing.py +++ b/examples/profiling/importing.py @@ -47,3 +47,5 @@ print("dxtb", t2 - t1) print("Param", t3 - t2) print("scipy", t4 - t3) + +del scipy, torch, GFN1_XTB diff --git a/examples/run-all.sh b/examples/run-all.sh new file mode 100755 index 000000000..3b0930f8f --- /dev/null +++ b/examples/run-all.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +set -e + +# search recursively for all python files +for example in $(find . 
-name "*.py"); do + title="Running $example" + + # match the length of the title + line=$(printf '=%.0s' $(seq 1 ${#title})) + + echo "$line" + echo "$title" + echo "$line" + + python3 "$example" + + printf "\n\n" +done diff --git a/setup.cfg b/setup.cfg index d5dd850c0..d291616f2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -41,10 +41,10 @@ install_requires = numpy<2 pydantic>=2.0.0 scipy - tad-dftd3>=0.2.2 - tad-dftd4>=0.1.0 - tad-libcint>=0.0.1 - tad-mctc>=0.1.5 + tad-dftd3>=0.3.0 + tad-dftd4>=0.2.0 + tad-libcint>=0.1.0 + tad-mctc>=0.2.0 tad-multicharge tomli tomli-w diff --git a/src/dxtb/_src/calculators/config/cache.py b/src/dxtb/_src/calculators/config/cache.py index 0982d68ac..0239d5ecb 100644 --- a/src/dxtb/_src/calculators/config/cache.py +++ b/src/dxtb/_src/calculators/config/cache.py @@ -62,6 +62,26 @@ class ConfigCacheStore: potential: bool """Whether to store the potential matrix.""" + def set(self, key: str, value: bool) -> None: + """ + Set configuration options using keyword arguments. + + Parameters + ---------- + key : str + The configuration key. + value : bool + The configuration value. + + Example + ------- + config.set("hcore", True) + """ + if not hasattr(self, key): + raise ValueError(f"Unknown configuration key: {key}") + + setattr(self, key, value) + class ConfigCache: """ diff --git a/src/dxtb/_src/calculators/properties/vibration/analysis.py b/src/dxtb/_src/calculators/properties/vibration/analysis.py index b66eee474..71123b463 100644 --- a/src/dxtb/_src/calculators/properties/vibration/analysis.py +++ b/src/dxtb/_src/calculators/properties/vibration/analysis.py @@ -112,8 +112,9 @@ def to_unit(self, value: Literal["freqs", "modes"], unit: str) -> Tensor: if value == "freqs": return self._convert(self.freqs, unit, self.converter_freqs) + # TODO: Implement conversion for normal modes (required?) # if value == "modes": - # return self._convert(self.modes, unit, self.converter_modes) + # return self._convert(self.modes, unit, self.converter_modes) raise ValueError(f"Unsupported value for conversion: {value}") @@ -143,7 +144,7 @@ def _get_rotational_modes(mass: Tensor, positions: Tensor): _, paxes = storch.eighb(im) # make z-axis rotation vector with smallest moment of inertia - # w = torch.flip(w, [-1]) + # _ = torch.flip(_, [-1]) paxes = torch.flip(paxes, [-1]) ex, ey, ez = paxes.mT @@ -288,7 +289,7 @@ def vib_analysis( # by 2π (ω = √λ / 2π) in order to immediately get frequencies in # Hartree: E = hbar * ω with hbar = 1 in atomic units. Dividing by 2π # effectively converts from angular frequency (ω) to the frequency in - # cycles per second (ν, Hz), which would requires the following + # cycles per second (ν, Hz), which would require the following # conversion to cm^-1: 1e-2 / units.CODATA.c / units.AU2SECOND. sgn = torch.sign(force_const_au) freqs_au = torch.sqrt(torch.abs(force_const_au)) * sgn diff --git a/src/dxtb/_src/calculators/types/__init__.py b/src/dxtb/_src/calculators/types/__init__.py index b31b0f177..fcf92e486 100644 --- a/src/dxtb/_src/calculators/types/__init__.py +++ b/src/dxtb/_src/calculators/types/__init__.py @@ -63,7 +63,14 @@ class **inherits from all types**, i.e., it provides the energy and properties if we need to differentiate for multiple properties at once (e.g., Hessian and dipole moment for IR spectra). Hence, the default is ``use_functorch=False``. 
""" -from .analytical import * -from .autograd import * -from .energy import * -from .numerical import * +from .analytical import AnalyticalCalculator +from .autograd import AutogradCalculator +from .energy import EnergyCalculator +from .numerical import NumericalCalculator + +__all__ = [ + "AnalyticalCalculator", + "AutogradCalculator", + "EnergyCalculator", + "NumericalCalculator", +] diff --git a/src/dxtb/_src/calculators/types/analytical.py b/src/dxtb/_src/calculators/types/analytical.py index 10882a609..b0db6ce3e 100644 --- a/src/dxtb/_src/calculators/types/analytical.py +++ b/src/dxtb/_src/calculators/types/analytical.py @@ -168,7 +168,7 @@ def forces_analytical( timer.stop("ograd") overlap = self.cache["overlap"] - assert isinstance(overlap, ints.types.Overlap) + assert isinstance(overlap, ints.types.OverlapIntegral) assert overlap.matrix is not None # density matrix @@ -200,7 +200,7 @@ def forces_analytical( assert self.integrals.hcore is not None cn = ncoord.cn_d3(self.numbers, positions) - dedcn, dedr = self.integrals.hcore.integral.get_gradient( + dedcn, dedr = self.integrals.hcore.get_gradient( positions, overlap.matrix, overlap_grad, @@ -410,7 +410,7 @@ def _forces_analytical( self.ihelp, self.opts.scf, intmats, - self.integrals.hcore.integral.refocc, + self.integrals.hcore.refocc, ) timer.stop("SCF") @@ -462,7 +462,7 @@ def _forces_analytical( ) cn = ncoord.cn_d3(self.numbers, positions) - dedcn, dedr = self.integrals.hcore.integral.get_gradient( + dedcn, dedr = self.integrals.hcore.get_gradient( positions, intmats.overlap, overlap_grad, @@ -524,6 +524,9 @@ def dipole_analytical( Tensor Electric dipole moment of shape `(..., 3)`. """ + # require caching for analytical calculation at end of function + kwargs["store_dipole"] = True + # run single point and check if integral is populated result = self.singlepoint(positions, chrg, spin, **kwargs) @@ -534,19 +537,23 @@ def dipole_analytical( f"be added automatically if the '{efield.LABEL_EFIELD}' " "interaction is added to the Calculator." ) - if dipint.matrix is None: + + # Use try except to raise more informative error message, because + # `dipint.matrix` already raises a RuntimeError if the matrix is None. + try: + _ = dipint.matrix + except RuntimeError as e: raise RuntimeError( "Dipole moment requires a dipole integral. They should " f"be added automatically if the '{efield.LABEL_EFIELD}' " - "interaction is added to the Calculator." - ) + "interaction is added to the Calculator. This is " + "probably a bug. 
Check the cache setup.\n\n" + f"Original error: {str(e)}" + ) from e # pylint: disable=import-outside-toplevel from ..properties.moments.dip import dipole - # dip = properties.dipole( - # numbers, positions, result.density, self.integrals.dipole - # ) qat = self.ihelp.reduce_orbital_to_atom(result.charges.mono) dip = dipole(qat, positions, result.density, dipint.matrix) return dip diff --git a/src/dxtb/_src/calculators/types/autograd.py b/src/dxtb/_src/calculators/types/autograd.py index fc6a1dbfb..dd36dd571 100644 --- a/src/dxtb/_src/calculators/types/autograd.py +++ b/src/dxtb/_src/calculators/types/autograd.py @@ -30,7 +30,7 @@ from dxtb import OutputHandler, timer from dxtb._src.components.interactions.field import efield as efield from dxtb._src.constants import defaults -from dxtb._src.typing import Any, Literal, Tensor +from dxtb._src.typing import Any, Callable, Literal, Tensor from ..properties.vibration import ( IRResult, @@ -155,8 +155,18 @@ def forces( # pylint: disable=import-outside-toplevel from tad_mctc.autograd import jacrev - # jacrev requires a scalar from `self.energy`! - deriv = jacrev(self.energy, argnums=0)(positions, chrg, spin, **kwargs) + try: + # jacrev requires a scalar from `self.energy`! + deriv = jacrev(self.energy, argnums=0)(positions, chrg, spin, **kwargs) + except RuntimeError as e: + if "clone is not supported by NestedIntSymNode" in str(e): + raise RuntimeError( + "This is a bug in PyTorch 2.3.0 and 2.3.1. " + "Try manually importing `torch._dynamo` before running " + "any calculation." + ) from e + raise + assert isinstance(deriv, Tensor) elif grad_mode == "row": @@ -497,19 +507,7 @@ def dipole_deriv( Tensor Cartesian dipole derivative of shape ``(..., 3, nat, 3)``. """ - - if use_analytical is True: - if not hasattr(self, "dipole_analytical") or not callable( - getattr(self, "dipole_analytical") - ): - raise ValueError( - "Analytical dipole moment not available. " - "Please use a calculator, which subclasses " - "the `AnalyticalCalculator`." - ) - dip_fcn = self.dipole_analytical # type: ignore - else: - dip_fcn = self.dipole + dip_fcn = self._get_dipole_fcn(use_analytical) if use_functorch is True: # pylint: disable=import-outside-toplevel @@ -592,20 +590,8 @@ def polarizability( # retrieve the efield interaction and the field field = self.interactions.get_interaction(efield.LABEL_EFIELD).field - if use_analytical is True: - if not hasattr(self, "dipole_analytical") or not callable( - getattr(self, "dipole_analytical") - ): - raise ValueError( - "Analytical dipole moment not available. " - "Please use a calculator, which subclasses " - "the `AnalyticalCalculator`." - ) - - # FIXME: Not working for Raman - dip_fcn = self.dipole_analytical # type: ignore - else: - dip_fcn = self.dipole + # FIXME: Not working for Raman + dip_fcn = self._get_dipole_fcn(use_analytical) if use_functorch is False: # pylint: disable=import-outside-toplevel @@ -948,6 +934,31 @@ def raman( return RamanResult(vib_res.freqs, intensities, depol) + ########################################################################## + + def _get_dipole_fcn(self, use_analytical: bool) -> Callable: + if use_analytical is False: + return self.dipole + + if not hasattr(self, "dipole_analytical"): + raise ValueError( + "Analytical dipole moment not available. " + "Please use a calculator that subclasses " + "the `AnalyticalCalculator`." + ) + if not callable(getattr(self, "dipole_analytical")): + raise ValueError( + "Calculator has an attribute `dipole_analytical` but it " + "is not callable. 
This should not happen and is an " + "implementation error." + ) + + self.opts.cache.store.dipole = True + + return self.dipole_analytical # type: ignore + + ########################################################################## + def calculate( self, properties: list[str], diff --git a/src/dxtb/_src/calculators/types/base.py b/src/dxtb/_src/calculators/types/base.py index 984e87822..ca2ca0cc0 100644 --- a/src/dxtb/_src/calculators/types/base.py +++ b/src/dxtb/_src/calculators/types/base.py @@ -133,9 +133,9 @@ def __init__( raman_depol: Tensor | None = None, # hcore: Tensor | None = None, - overlap: ints.types.Overlap | None = None, - dipint: ints.types.Dipole | None = None, - quadint: ints.types.Quadrupole | None = None, + overlap: ints.types.OverlapIntegral | None = None, + dipint: ints.types.DipoleIntegral | None = None, + quadint: ints.types.QuadrupoleIntegral | None = None, # bond_orders: Tensor | None = None, coefficients: Tensor | None = None, @@ -404,6 +404,12 @@ class BaseCalculator(GetPropertiesMixin, TensorLike): results: dict[str, Any] """Results container.""" + _ncalcs: int + """ + Number of calculations performed with the calculator. Helpful for keeping + track of cache hits and actual new calculations. + """ + def __init__( self, numbers: Tensor, @@ -481,6 +487,7 @@ def __init__( if self.opts.batch_mode == 0 and numbers.ndim > 1: self.opts.batch_mode = 1 + # TODO: Should the IndexHelper be a singleton? self.ihelp = IndexHelper.from_numbers(numbers, par, self.opts.batch_mode) ################ @@ -573,6 +580,7 @@ def __init__( "interaction." ) self.opts.ints.level = max(labels.INTLEVEL_DIPOLE, self.opts.ints.level) + if efield_grad.LABEL_EFIELD_GRAD in self.interactions.labels: if self.opts.ints.level < labels.INTLEVEL_DIPOLE: OutputHandler.warn( @@ -582,21 +590,29 @@ def __init__( ) self.opts.ints.level = max(labels.INTLEVEL_QUADRUPOLE, self.opts.ints.level) - # setup integral - driver = self.opts.ints.driver - self.integrals = ints.Integrals( - numbers, par, self.ihelp, driver=driver, intlevel=self.opts.ints.level, **dd - ) + # setup integral driver and integral container + mgr = ints.DriverManager(self.opts.ints.driver, **dd) + mgr.create_driver(numbers, par, self.ihelp) + + self.integrals = ints.Integrals(mgr, intlevel=self.opts.ints.level, **dd) if self.opts.ints.level >= labels.INTLEVEL_OVERLAP: - self.integrals.hcore = ints.types.HCore(numbers, par, self.ihelp, **dd) - self.integrals.overlap = ints.types.Overlap(driver=driver, **dd) + self.integrals.hcore = ints.factories.new_hcore( + numbers, par, self.ihelp, **dd + ) + self.integrals.overlap = ints.factories.new_overlap( + driver=mgr.driver_type, **dd + ) if self.opts.ints.level >= labels.INTLEVEL_DIPOLE: - self.integrals.dipole = ints.types.Dipole(driver=driver, **dd) + self.integrals.dipole = ints.factories.new_dipint( + driver=mgr.driver_type, **dd + ) if self.opts.ints.level >= labels.INTLEVEL_QUADRUPOLE: - self.integrals.quadrupole = ints.types.Quadrupole(driver=driver, **dd) + self.integrals.quadrupole = ints.factories.new_quadint( + driver=mgr.driver_type, **dd + ) OutputHandler.write_stdout("done\n", v=4) diff --git a/src/dxtb/_src/calculators/types/decorators.py b/src/dxtb/_src/calculators/types/decorators.py index 3408d9cd5..dc85be41b 100644 --- a/src/dxtb/_src/calculators/types/decorators.py +++ b/src/dxtb/_src/calculators/types/decorators.py @@ -43,7 +43,6 @@ if TYPE_CHECKING: from ..base import Calculator -del TYPE_CHECKING __all__ = [ "requires_positions_grad", diff --git 
a/src/dxtb/_src/calculators/types/energy.py b/src/dxtb/_src/calculators/types/energy.py index dba8e07ba..2bbce5ad9 100644 --- a/src/dxtb/_src/calculators/types/energy.py +++ b/src/dxtb/_src/calculators/types/energy.py @@ -23,15 +23,11 @@ from __future__ import annotations -import logging - import torch from tad_mctc.convert import any_to_tensor from tad_mctc.io.checks import content_checks, shape_checks -from dxtb import OutputHandler -from dxtb import integrals as ints -from dxtb import labels +from dxtb import OutputHandler, labels from dxtb._src import scf from dxtb._src.constants import defaults from dxtb._src.integral.container import IntegralMatrices @@ -197,10 +193,10 @@ def singlepoint( # Core Hamiltonian integral (requires overlap internally!) # # This should be the final integral, because the others are - # potentially calculated on CPU (libcint) even in GPU runs. + # potentially calculated on CPU (libcint), even in GPU runs. # To avoid unnecessary data transfer, the core Hamiltonian should - # be last. Internally, the overlap integral is only transfered back - # to GPU when all multipole integrals are calculated. + # be calculated last. Internally, the overlap integral is only + # transferred back to GPU when all multipole integrals are calculated. if self.opts.ints.level >= labels.INTLEVEL_HCORE: OutputHandler.write_stdout_nf(" - Core Hamiltonian ... ", v=3) timer.start("Core Hamiltonian", parent_uid="Integrals") @@ -216,7 +212,7 @@ def singlepoint( # While one can theoretically skip the core Hamiltonian, the # current implementation does not account for this case because the # reference occupation is necessary for the SCF procedure. - if self.integrals.hcore is None or self.integrals.hcore.matrix is None: + if self.integrals.hcore is None: raise NotImplementedError( "Core Hamiltonian missing. Skipping the Core Hamiltonian in " "the SCF is currently not supported. Please increase the " @@ -261,7 +257,7 @@ def singlepoint( self.ihelp, self.opts.scf, intmats, - self.integrals.hcore.integral.refocc, + self.integrals.hcore.refocc, ) timer.stop("SCF") @@ -326,14 +322,34 @@ def singlepoint( if kwargs.get("store_fock", copts.fock): self.cache["fock"] = scf_results["hamiltonian"] + if kwargs.get("store_hcore", copts.hcore): self.cache["hcore"] = self.integrals.hcore + else: + if self.integrals.hcore is not None: + if self.integrals.hcore.requires_grad is False: + self.integrals.hcore.clear() + if kwargs.get("store_overlap", copts.overlap): self.cache["overlap"] = self.integrals.overlap + else: + if self.integrals.overlap is not None: + if self.integrals.overlap.requires_grad is False: + self.integrals.overlap.clear() + if kwargs.get("store_dipole", copts.dipole): self.cache["dipint"] = self.integrals.dipole + else: + if self.integrals.dipole is not None: + if self.integrals.dipole.requires_grad is False: + self.integrals.dipole.clear() + if kwargs.get("store_quadrupole", copts.quadrupole): self.cache["quadint"] = self.integrals.quadrupole + else: + if self.integrals.quadrupole is not None: + if self.integrals.quadrupole.requires_grad is False: + self.integrals.quadrupole.clear() self._ncalcs += 1 return result @@ -403,18 +419,24 @@ def bond_orders( """ self.singlepoint(positions, chrg, spin, **kwargs) + ovlp_msg = ( + "Overlap matrix not found in cache. The overlap is not saved " + "by default. 
Enable saving either via the calculator options " + "(`calc.opts.cache.store.overlap = True`) or by passing the " + "`store_overlap=True` keyword argument to the called method, e.g., " + "`calc.energy(positions, store_overlap=True)`" + ) + overlap = self.cache["overlap"] if overlap is None: - raise RuntimeError( - "Overlap matrix not found in cache. The overlap is not saved " - "per default. Enable saving either via the calculator options " - "(`calc.opts.cache.store.overlap = True`) or by passing the " - "`store_overlap=True` keyword argument to called method, e.g., " - "`calc.energy(positions, store_overlap=True)" - ) + raise RuntimeError(ovlp_msg) + + # pylint: disable=import-outside-toplevel + from dxtb._src.integral.types import OverlapIntegral - assert isinstance(overlap, ints.types.Overlap) - assert overlap.matrix is not None + assert isinstance(overlap, OverlapIntegral) + if overlap.matrix is None: + raise RuntimeError(ovlp_msg) density = self.cache["density"] if density is None: diff --git a/src/dxtb/_src/calculators/types/numerical.py b/src/dxtb/_src/calculators/types/numerical.py index 963311832..7ba09ddb2 100644 --- a/src/dxtb/_src/calculators/types/numerical.py +++ b/src/dxtb/_src/calculators/types/numerical.py @@ -293,6 +293,7 @@ def dipole_numerical( chrg: Tensor | float | int = defaults.CHRG, spin: Tensor | float | int | None = defaults.SPIN, step_size: int | float = defaults.STEP_SIZE, + **kwargs: Any, ) -> Tensor: r""" Numerically calculate the electric dipole moment :math:`\mu`. @@ -332,11 +333,11 @@ def dipole_numerical( with OutputHandler.with_verbosity(0): field[..., i] += step_size self.interactions.update_efield(field=field) - gr = self.energy(positions, chrg, spin) + gr = self.energy(positions, chrg, spin, **kwargs) field[..., i] -= 2 * step_size self.interactions.update_efield(field=field) - gl = self.energy(positions, chrg, spin) + gl = self.energy(positions, chrg, spin, **kwargs) field[..., i] += step_size self.interactions.update_efield(field=field) @@ -359,6 +360,7 @@ def dipole_deriv_numerical( chrg: Tensor | float | int = defaults.CHRG, spin: Tensor | float | int | None = defaults.SPIN, step_size: int | float = defaults.STEP_SIZE, + **kwargs: Any, ) -> Tensor: r""" Numerically calculate cartesian dipole derivative :math:`\mu'`. @@ -411,10 +413,10 @@ def dipole_deriv_numerical( for j in range(3): with OutputHandler.with_verbosity(0): positions[..., i, j] += step_size - r = _dipfcn(positions, chrg, spin) + r = _dipfcn(positions, chrg, spin, **kwargs) positions[..., i, j] -= 2 * step_size - l = _dipfcn(positions, chrg, spin) + l = _dipfcn(positions, chrg, spin, **kwargs) positions[..., i, j] += step_size deriv[..., :, i, j] = 0.5 * (r - l) / step_size @@ -438,6 +440,7 @@ def polarizability_numerical( chrg: Tensor | float | int = defaults.CHRG, spin: Tensor | float | int | None = defaults.SPIN, step_size: int | float = defaults.STEP_SIZE, + **kwargs: Any, ) -> Tensor: r""" Numerically calculate the polarizability tensor :math:`\alpha`. 
@@ -489,11 +492,11 @@ def polarizability_numerical( with OutputHandler.with_verbosity(0): field[..., i] += step_size self.interactions.update_efield(field=field) - gr = _dipfcn(positions, chrg, spin) + gr = _dipfcn(positions, chrg, spin, **kwargs) field[..., i] -= 2 * step_size self.interactions.update_efield(field=field) - gl = _dipfcn(positions, chrg, spin) + gl = _dipfcn(positions, chrg, spin, **kwargs) field[..., i] += step_size self.interactions.update_efield(field=field) diff --git a/src/dxtb/_src/cli/driver.py b/src/dxtb/_src/cli/driver.py index f7e5e4708..fa102ea48 100644 --- a/src/dxtb/_src/cli/driver.py +++ b/src/dxtb/_src/cli/driver.py @@ -36,6 +36,7 @@ from dxtb._src.components.interactions.field import new_efield from dxtb._src.constants import labels from dxtb._src.timing import timer +from dxtb._src.typing import Tensor __all__ = ["Driver"] @@ -118,7 +119,7 @@ def _set_attr(self, attr: str) -> int | list[int]: return vals - def singlepoint(self) -> Result | None: + def singlepoint(self) -> Result | Tensor: timer.start("Setup") args = self.args @@ -249,47 +250,47 @@ def singlepoint(self) -> Result | None: if args.forces is True: positions.requires_grad_(True) - forces = calc.forces(positions, chrg) + result = calc.forces(positions, chrg) calc.reset() timer.print() - print_grad(forces.clone(), numbers) + print_grad(result.clone(), numbers) # io.OutputHandler.dump_warnings() - return + return result if args.forces_numerical is True: timer.start("Forces") - forces = calc.forces_numerical(positions, chrg) + result = calc.forces_numerical(positions, chrg) timer.stop("Forces") calc.reset() - print_grad(forces.clone(), numbers) + print_grad(result.clone(), numbers) # io.OutputHandler.dump_warnings() - return + return result if args.hessian is True: positions.requires_grad_(True) timer.start("Hessian") - hessian = calc.hessian(positions, chrg) + result = calc.hessian(positions, chrg) timer.stop("Hessian") calc.reset() - print(hessian.clone().detach()) + print(result.clone().detach()) # io.OutputHandler.dump_warnings() - return + return result if args.hessian_numerical is True: positions.requires_grad_(True) timer.start("Hessian") - hessian = calc.hessian_numerical(positions, chrg) + result = calc.hessian_numerical(positions, chrg) timer.stop("Hessian") calc.reset() - print(hessian.clone()) + print(result.clone()) # io.OutputHandler.dump_warnings() - return + return result if args.ir is True: # TODO: Better handling here @@ -360,6 +361,8 @@ def singlepoint(self) -> Result | None: io.OutputHandler.dump_warnings() return result + raise RuntimeError("No calculation was performed.") + def __repr__(self) -> str: # pragma: no cover """Custom print representation of class.""" return f"{self.__class__.__name__}({self.args})" diff --git a/src/dxtb/_src/components/classicals/dispersion/__init__.py b/src/dxtb/_src/components/classicals/dispersion/__init__.py index 5dff1b1f0..70bafb47d 100644 --- a/src/dxtb/_src/components/classicals/dispersion/__init__.py +++ b/src/dxtb/_src/components/classicals/dispersion/__init__.py @@ -18,7 +18,6 @@ Dispersion models in the extended tight-binding model. 
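+
+A minimal construction sketch (the geometry is a placeholder, and the import
+path assumes ``new_dispersion`` is re-exported via
+``dxtb.components.dispersion``; for GFN1-xTB, the factory yields the D3 model):
+
+.. code-block:: python
+
+    import torch
+    from dxtb import GFN1_XTB
+    from dxtb.components.dispersion import new_dispersion
+
+    numbers = torch.tensor([1, 1])  # H2
+    disp = new_dispersion(numbers, GFN1_XTB)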
""" -from .base import Dispersion from .d3 import DispersionD3 from .d4 import DispersionD4 from .factory import new_dispersion diff --git a/src/dxtb/_src/components/classicals/dispersion/d3.py b/src/dxtb/_src/components/classicals/dispersion/d3.py index 5cc7a5ea2..f6c4254ea 100644 --- a/src/dxtb/_src/components/classicals/dispersion/d3.py +++ b/src/dxtb/_src/components/classicals/dispersion/d3.py @@ -28,7 +28,8 @@ from dxtb._src.typing import Any, CountingFunction, Tensor -from .base import ClassicalCache, Dispersion +from ..base import ClassicalCache +from .base import Dispersion __all__ = ["DispersionD3", "DispersionD3Cache"] diff --git a/src/dxtb/_src/components/classicals/dispersion/d4.py b/src/dxtb/_src/components/classicals/dispersion/d4.py index 8f7808ed8..3594ead21 100644 --- a/src/dxtb/_src/components/classicals/dispersion/d4.py +++ b/src/dxtb/_src/components/classicals/dispersion/d4.py @@ -27,7 +27,8 @@ from dxtb._src.typing import Any, Tensor -from .base import ClassicalCache, Dispersion +from ..base import ClassicalCache +from .base import Dispersion __all__ = ["DispersionD4", "DispersionD4Cache"] diff --git a/src/dxtb/_src/components/classicals/halogen/__init__.py b/src/dxtb/_src/components/classicals/halogen/__init__.py index 32d4f3e18..8183ee62f 100644 --- a/src/dxtb/_src/components/classicals/halogen/__init__.py +++ b/src/dxtb/_src/components/classicals/halogen/__init__.py @@ -56,5 +56,7 @@ print(energy.sum(-1)) # tensor(0.0025) """ -from .factory import * -from .hal import * +from .factory import new_halogen +from .hal import LABEL_HALOGEN, Halogen + +__all__ = ["new_halogen", "Halogen", "LABEL_HALOGEN"] diff --git a/src/dxtb/_src/components/interactions/coulomb/secondorder.py b/src/dxtb/_src/components/interactions/coulomb/secondorder.py index d638e1b40..5f8c46df3 100644 --- a/src/dxtb/_src/components/interactions/coulomb/secondorder.py +++ b/src/dxtb/_src/components/interactions/coulomb/secondorder.py @@ -32,11 +32,11 @@ # Define atomic numbers, positions, and charges numbers = torch.tensor([14, 1, 1, 1, 1]) positions = torch.tensor([ - [0.00000000000000, -0.00000000000000, 0.00000000000000], - [1.61768389755830, 1.61768389755830, -1.61768389755830], + [+0.00000000000000, -0.00000000000000, +0.00000000000000], + [+1.61768389755830, +1.61768389755830, -1.61768389755830], [-1.61768389755830, -1.61768389755830, -1.61768389755830], - [1.61768389755830, -1.61768389755830, 1.61768389755830], - [-1.61768389755830, 1.61768389755830, 1.61768389755830], + [+1.61768389755830, -1.61768389755830, +1.61768389755830], + [-1.61768389755830, +1.61768389755830, +1.61768389755830], ]) q = torch.tensor([ diff --git a/src/dxtb/_src/components/interactions/coulomb/thirdorder.py b/src/dxtb/_src/components/interactions/coulomb/thirdorder.py index c10a3738c..483d0e8d6 100644 --- a/src/dxtb/_src/components/interactions/coulomb/thirdorder.py +++ b/src/dxtb/_src/components/interactions/coulomb/thirdorder.py @@ -33,11 +33,11 @@ # Define atomic numbers and their positions numbers = torch.tensor([14, 1, 1, 1, 1]) positions = torch.tensor([ - [0.00000000000000, -0.00000000000000, 0.00000000000000], - [1.61768389755830, 1.61768389755830, -1.61768389755830], + [+0.00000000000000, -0.00000000000000, +0.00000000000000], + [+1.61768389755830, +1.61768389755830, -1.61768389755830], [-1.61768389755830, -1.61768389755830, -1.61768389755830], - [1.61768389755830, -1.61768389755830, 1.61768389755830], - [-1.61768389755830, 1.61768389755830, 1.61768389755830], + [+1.61768389755830, -1.61768389755830, 
+1.61768389755830], + [-1.61768389755830, +1.61768389755830, +1.61768389755830], ]) # Atomic charges diff --git a/src/dxtb/_src/components/interactions/solvation/alpb.py b/src/dxtb/_src/components/interactions/solvation/alpb.py index f92e45899..7682f2d8b 100644 --- a/src/dxtb/_src/components/interactions/solvation/alpb.py +++ b/src/dxtb/_src/components/interactions/solvation/alpb.py @@ -29,11 +29,11 @@ numbers = torch.tensor([14, 1, 1, 1, 1]) positions = torch.tensor([ - [0.00000000000000, -0.00000000000000, 0.00000000000000], - [1.61768389755830, 1.61768389755830, -1.61768389755830], + [+0.00000000000000, -0.00000000000000, +0.00000000000000], + [+1.61768389755830, +1.61768389755830, -1.61768389755830], [-1.61768389755830, -1.61768389755830, -1.61768389755830], - [1.61768389755830, -1.61768389755830, 1.61768389755830], - [-1.61768389755830, 1.61768389755830, 1.61768389755830], + [+1.61768389755830, -1.61768389755830, +1.61768389755830], + [-1.61768389755830, +1.61768389755830, +1.61768389755830], ]) charges = torch.tensor([ -8.41282505804719e-2, diff --git a/src/dxtb/_src/components/interactions/solvation/born.py b/src/dxtb/_src/components/interactions/solvation/born.py index 044bc0ede..c5326f3db 100644 --- a/src/dxtb/_src/components/interactions/solvation/born.py +++ b/src/dxtb/_src/components/interactions/solvation/born.py @@ -31,11 +31,11 @@ # Define atomic numbers and positions of the atoms numbers = torch.tensor([14, 1, 1, 1, 1]) positions = torch.tensor([ - [0.00000000000000, -0.00000000000000, 0.00000000000000], - [1.61768389755830, 1.61768389755830, -1.61768389755830], + [+0.00000000000000, -0.00000000000000, +0.00000000000000], + [+1.61768389755830, +1.61768389755830, -1.61768389755830], [-1.61768389755830, -1.61768389755830, -1.61768389755830], - [1.61768389755830, -1.61768389755830, 1.61768389755830], - [-1.61768389755830, 1.61768389755830, 1.61768389755830], + [+1.61768389755830, -1.61768389755830, +1.61768389755830], + [-1.61768389755830, +1.61768389755830, +1.61768389755830], ]) # Calculate the Born radii for the given atomic configuration diff --git a/src/dxtb/_src/constants/labels/method.py b/src/dxtb/_src/constants/labels/method.py index 1109a5c02..087227401 100644 --- a/src/dxtb/_src/constants/labels/method.py +++ b/src/dxtb/_src/constants/labels/method.py @@ -19,6 +19,16 @@ =============== """ +__all__ = [ + "GFN0_XTB", + "GFN0_XTB_STRS", + "GFN1_XTB", + "GFN1_XTB_STRS", + "GFN2_XTB", + "GFN2_XTB_STRS", + "GFN_XTB_MAP", +] + # xtb GFN0_XTB = 0 """Integer code for GFN0-xTB.""" diff --git a/src/dxtb/_src/constants/labels/scf.py b/src/dxtb/_src/constants/labels/scf.py index 97cc70813..c973296ff 100644 --- a/src/dxtb/_src/constants/labels/scf.py +++ b/src/dxtb/_src/constants/labels/scf.py @@ -21,6 +21,42 @@ Labels for SCF-related options. 
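+
+These integer codes (or their string aliases, resolved through the ``*_MAP``
+dictionaries) are passed through the calculator options. A sketch, where the
+``"scf_mode"`` key follows the examples elsewhere in this codebase and the
+``"mixer"`` key is assumed analogous:
+
+.. code-block:: python
+
+    from dxtb import labels
+
+    opts = {
+        "scf_mode": labels.SCF_MODE_FULL,
+        "mixer": labels.MIXER_ANDERSON,
+    }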
""" +__all__ = [ + "GUESS_EEQ", + "GUESS_EEQ_STRS", + "GUESS_SAD", + "GUESS_SAD_STRS", + "GUESS_MAP", + "SCF_MODE_FULL", + "SCF_MODE_FULL_STRS", + "SCF_MODE_IMPLICIT", + "SCF_MODE_IMPLICIT_STRS", + "SCF_MODE_IMPLICIT_NON_PURE", + "SCF_MODE_IMPLICIT_NON_PURE_STRS", + "SCF_MODE_EXPERIMENTAL", + "SCF_MODE_EXPERIMENTAL_STRS", + "SCF_MODE_MAP", + "SCP_MODE_FOCK", + "SCP_MODE_FOCK_STRS", + "SCP_MODE_CHARGE", + "SCP_MODE_CHARGE_STRS", + "SCP_MODE_POTENTIAL", + "SCP_MODE_POTENTIAL_STRS", + "SCP_MODE_MAP", + "FERMI_PARTITION_EQUAL", + "FERMI_PARTITION_EQUAL_STRS", + "FERMI_PARTITION_ATOMIC", + "FERMI_PARTITION_ATOMIC_STRS", + "FERMI_PARTITION_MAP", + "MIXER_LINEAR", + "MIXER_LINEAR_STRS", + "MIXER_ANDERSON", + "MIXER_ANDERSON_STRS", + "MIXER_BROYDEN", + "MIXER_BROYDEN_STRS", + "MIXER_MAP", +] + # guess GUESS_EEQ = 0 """Integer code for EEQ guess.""" diff --git a/src/dxtb/_src/exlibs/xitorch/_core/editable_module.py b/src/dxtb/_src/exlibs/xitorch/_core/editable_module.py index d67a2f79f..cd98ae827 100644 --- a/src/dxtb/_src/exlibs/xitorch/_core/editable_module.py +++ b/src/dxtb/_src/exlibs/xitorch/_core/editable_module.py @@ -181,16 +181,13 @@ def _get_unique_params_idxs( id_param = id(param) # search the id if it has been added to the list - try: + if id_param in ids: jfound = ids.index(id_param) idx_map[jfound].append(i) - continue - except ValueError: - pass - - ids.append(id_param) - idxs.append(i) - idx_map.append([i]) + else: + ids.append(id_param) + idxs.append(i) + idx_map.append([i]) self._number_of_params[methodname] = len(allparams) self._unique_params_idxs[methodname] = idxs diff --git a/src/dxtb/_src/exlibs/xitorch/_utils/decorators.py b/src/dxtb/_src/exlibs/xitorch/_utils/decorators.py deleted file mode 100644 index b096f040f..000000000 --- a/src/dxtb/_src/exlibs/xitorch/_utils/decorators.py +++ /dev/null @@ -1,56 +0,0 @@ -# This file is part of dxtb, modified from xitorch/xitorch. -# -# SPDX-Identifier: Apache-2.0 -# Copyright (C) 2024 Grimme Group -# -# Original file licensed under the MIT License by xitorch/xitorch. -# Modifications made by Grimme Group. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import functools -import inspect -import warnings - -__all__ = ["deprecated"] - - -def deprecated(date_str): - return lambda obj: _deprecated(obj, date_str) - - -def _deprecated(obj, date_str): - if inspect.isfunction(obj): - name = "Function %s" % (obj.__str__()) - elif inspect.isclass(obj): - name = "Class %s" % (obj.__name__) - - if inspect.ismethod(obj) or inspect.isfunction(obj): - - @functools.wraps(obj) - def fcn(*args, **kwargs): - warnings.warn(f"{name} is deprecated since {date_str}", stacklevel=2) - return obj(*args, **kwargs) - - return fcn - - elif inspect.isclass(obj): - # replace the __init__ function - old_init = obj.__init__ - - @functools.wraps(old_init) - def newinit(*args, **kwargs): - warnings.warn(f"{name} is deprecated since {date_str}", stacklevel=2) - return old_init(*args, **kwargs) - - obj.__init__ = newinit - return obj diff --git a/src/dxtb/_src/integral/abc.py b/src/dxtb/_src/integral/abc.py new file mode 100644 index 000000000..adb407f67 --- /dev/null +++ b/src/dxtb/_src/integral/abc.py @@ -0,0 +1,85 @@ +# This file is part of dxtb. +# +# SPDX-Identifier: Apache-2.0 +# Copyright (C) 2024 Grimme Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Integrals: Abstract Base Classes +================================ +
+Abstract base class for integrals. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod + +from dxtb._src.typing import TYPE_CHECKING, Any, Tensor + +if TYPE_CHECKING: + from dxtb._src.integral.base import IntDriver + + +__all__ = ["IntegralABC"] + + +class IntegralABC(ABC): + """ + Abstract base class for integral implementations. + + All integral calculations are executed by this class. + """ + + @abstractmethod + def build(self, driver: IntDriver, **kwargs: Any) -> Tensor: + """ + Create the integral matrix. + + Parameters + ---------- + driver : IntDriver + Integral driver for the calculation. + + Returns + ------- + Tensor + Integral matrix. + """ + + @abstractmethod + def get_gradient(self, driver: IntDriver, **kwargs: Any) -> Tensor: + """ + Calculate the full nuclear gradient matrix of the integral. + + Parameters + ---------- + driver : IntDriver + Integral driver for the calculation. + + Returns + ------- + Tensor + Nuclear integral derivative matrix. + """ + + @abstractmethod + def normalize(self, norm: Tensor | None = None, **kwargs: Any) -> None: + """ + Normalize the integral (changes ``self.matrix``). + + Parameters + ---------- + norm : Tensor, optional + Overlap norm to normalize the integral. + """ diff --git a/src/dxtb/_src/integral/base.py b/src/dxtb/_src/integral/base.py index e7d22caac..1290935c4 100644 --- a/src/dxtb/_src/integral/base.py +++ b/src/dxtb/_src/integral/base.py @@ -18,25 +18,25 @@ Integrals: Base Classes ======================= -Base class for Integrals classes and their actual implementations. +Base class for integral classes and their actual implementations. 
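+
+A typical lifecycle, sketched with placeholder objects (``driver`` stands for
+any concrete :class:`IntDriver`, ``integral`` for any concrete
+:class:`BaseIntegral`):
+
+.. code-block:: python
+
+    driver.setup(positions)          # geometry-dependent driver setup
+    matrix = integral.build(driver)  # compute and store the integral matrix
+    integral.normalize()             # in-place scaling with the overlap norm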
""" from __future__ import annotations -from abc import ABC, abstractmethod +from abc import abstractmethod import torch +from tad_mctc.math import einsum from dxtb import IndexHelper from dxtb._src.basis.bas import Basis from dxtb._src.param import Param -from dxtb._src.typing import PathLike, Tensor, TensorLike +from dxtb._src.typing import Literal, PathLike, Self, Tensor, TensorLike -__all__ = [ - "BaseIntegralImplementation", - "IntDriver", - "IntegralContainer", -] +from .abc import IntegralABC +from .utils import snorm + +__all__ = ["BaseIntegral", "IntDriver"] class IntDriver(TensorLike): @@ -51,6 +51,9 @@ class IntDriver(TensorLike): ihelp: IndexHelper """Helper class for indexing.""" + family: Literal["PyTorch", "libcint"] + """Label for integral implementation family.""" + __label: str """Identifier label for integral driver.""" @@ -181,15 +184,12 @@ def setup(self, positions: Tensor, **kwargs) -> None: """ def __str__(self) -> str: - dict_repr = [] - for key, value in self.__dict__.items(): - if isinstance(value, Tensor): - value_repr = f"{value.shape}" - else: - value_repr = repr(value) - dict_repr.append(f" {key}: {value_repr}") - dict_str = "{\n" + ",\n".join(dict_repr) + "\n}" - return f"{self.__class__.__name__}({dict_str})" + return ( + f"{self.__class__.__name__}(" + f"Family: {self.family}, " + f"Number of Atoms: {self.numbers.shape[-1]}, " + f"Setup?: {self.is_setup()})" + ) def __repr__(self) -> str: return str(self) @@ -198,70 +198,40 @@ def __repr__(self) -> str: ######################################################### -class IntegralImplementationABC(ABC): +class BaseIntegral(IntegralABC, TensorLike): """ - Abstract base class for (actual) integral implementations. + Base class for integral implementations. - All integral calculations are executed by this class. + All integral calculations are executed by its child classes. """ - @abstractmethod - def build(self, driver: IntDriver) -> Tensor: - """ - Create the integral matrix. - - Parameters - ---------- - driver : IntDriver - Integral driver for the calculation. + _matrix: Tensor | None + """Internal storage variable for the integral matrix.""" - Returns - ------- - Tensor - Integral matrix. - """ + _gradient: Tensor | None + """Internal storage variable for the cartesian gradient.""" - @abstractmethod - def get_gradient(self, driver: IntDriver) -> Tensor: - """ - Create the nuclear integral derivative matrix. - - Parameters - ---------- - driver : IntDriver - Integral driver for the calculation. - - Returns - ------- - Tensor - Nuclear integral derivative matrix. - """ - - -class BaseIntegralImplementation(IntegralImplementationABC, TensorLike): - """ - Base class for (actual) integral implementations. + _norm: Tensor | None + """Internal storage variable for the overlap norm.""" - All integral calculations are executed by this class. 
- """ + family: str | None + """Family of the integral implementation (PyTorch or libcint).""" - __slots__ = ["_matrix", "_norm", "_gradient"] + __slots__ = ["_matrix", "_gradient", "_norm"] def __init__( self, device: torch.device | None = None, dtype: torch.dtype | None = None, - normalize: bool = True, _matrix: Tensor | None = None, - _norm: Tensor | None = None, _gradient: Tensor | None = None, + _norm: Tensor | None = None, ) -> None: super().__init__(device=device, dtype=dtype) self.label = self.__class__.__name__ - self.normalize = normalize - self._matrix = _matrix self._norm = _norm + self._matrix = _matrix self._gradient = _gradient def checks(self, driver: IntDriver) -> None: @@ -279,6 +249,86 @@ def checks(self, driver: IntDriver) -> None: "before passing the driver to the integral build." ) + if "pytorch" in self.label.casefold(): + # pylint: disable=import-outside-toplevel + from .driver.pytorch.driver import BaseIntDriverPytorch as _BaseIntDriver + + elif "libcint" in self.label.casefold(): + # pylint: disable=import-outside-toplevel + from .driver.libcint.driver import BaseIntDriverLibcint as _BaseIntDriver + + else: + raise RuntimeError(f"Unknown integral implementation: '{self.label}'.") + + if not isinstance(driver, _BaseIntDriver): + raise RuntimeError(f"Wrong integral driver selected for '{self.label}'.") + + def clear(self) -> None: + """ + Clear the integral matrix and gradient. + """ + self._matrix = None + self._norm = None + self._gradient = None + + @property + def requires_grad(self) -> bool: + """ + Check if any field of the integral class is requires gradient. + + Returns + ------- + bool + Flag for gradient requirement. + """ + for field in (self._matrix, self._gradient, self._norm): + if field is not None and field.requires_grad: + return True + + return False + + def normalize(self, norm: Tensor | None = None) -> None: + """ + Normalize the integral (changes ``self.matrix``). + + Parameters + ---------- + norm : Tensor + Overlap norm to normalize the integral. + """ + if norm is None: + if self.norm is not None: + norm = self.norm + else: + norm = snorm(self.matrix) + + if norm.ndim == 1: + einsum_str = "...ij,i,j->...ij" + elif norm.ndim == 2: + einsum_str = "b...ij,bi,bj->b...ij" + else: + raise ValueError(f"Invalid norm shape: {norm.shape}") + + self.matrix = einsum(einsum_str, self.matrix, norm, norm) + + def normalize_gradient(self, norm: Tensor | None = None) -> None: + """ + Normalize the gradient (changes ``self.gradient``). + + Parameters + ---------- + norm : Tensor + Overlap norm to normalize the integral. + """ + if norm is None: + if self.norm is not None: + norm = self.norm + else: + norm = snorm(self.matrix) + + einsum_str = "...ijx,...i,...j->...ijx" + self.gradient = einsum(einsum_str, self.gradient, norm, norm) + def to_pt(self, path: PathLike | None = None) -> None: """ Save the integral matrix to a file. @@ -294,10 +344,45 @@ def to_pt(self, path: PathLike | None = None) -> None: torch.save(self.matrix, path) + def to(self, device: torch.device) -> Self: + """ + Returns a copy of the integral on the specified device "``device``". + + This is essentially a wrapper around the :meth:`to` method of the + :class:`TensorLike` class, but explicitly also moves the integral + matrix. + + Parameters + ---------- + device : torch.device + Device to which all associated tensors should be moved. + + Returns + ------- + Self + A copy of the integral placed on the specified device. 
+ """ + if self._gradient is not None: + self._gradient = self._gradient.to(device=device) + + if self._norm is not None: + self._norm = self._norm.to(device=device) + + if self._matrix is not None: + self._matrix = self._matrix.to(device=device) + + return super().to(device=device) + @property def matrix(self) -> Tensor: if self._matrix is None: - raise RuntimeError("Integral matrix has not been calculated.") + raise RuntimeError( + "Integral matrix not found. This can be caused by two " + "reasons:\n" + "1. The integral has not been calculated yet.\n" + "2. The integral was cleared, despite being required " + "in a subsequent calculation. Check the cache settings." + ) return self._matrix @matrix.setter @@ -305,9 +390,7 @@ def matrix(self, mat: Tensor) -> None: self._matrix = mat @property - def norm(self) -> Tensor: - if self._norm is None: - raise RuntimeError("Overlap norm has not been calculated.") + def norm(self) -> Tensor | None: return self._norm @norm.setter @@ -328,47 +411,12 @@ def __str__(self) -> str: d = self.__dict__.copy() if self._matrix is not None: d["_matrix"] = self._matrix.shape - if self._norm is not None: - d["_norm"] = self._norm.shape if self._gradient is not None: d["_gradient"] = self._gradient.shape + if self._norm is not None: + d["_norm"] = self._norm.shape return f"{self.__class__.__name__}({d})" def __repr__(self) -> str: return str(self) - - -######################################################### - - -class IntegralContainer(TensorLike): - """ - Base class for integral container. - """ - - def __init__( - self, - device: torch.device | None = None, - dtype: torch.dtype | None = None, - _run_checks: bool = True, - ): - super().__init__(device, dtype) - self._run_checks = _run_checks - - @property - def run_checks(self) -> bool: - return self._run_checks - - @run_checks.setter - def run_checks(self, run_checks: bool) -> None: - current = self.run_checks - self._run_checks = run_checks - - # switching from False to True should automatically run checks - if current is False and run_checks is True: - self.checks() - - @abstractmethod - def checks(self) -> None: - """Run checks for integrals.""" diff --git a/src/dxtb/_src/integral/container.py b/src/dxtb/_src/integral/container.py index fde186218..6bff2aa94 100644 --- a/src/dxtb/_src/integral/container.py +++ b/src/dxtb/_src/integral/container.py @@ -24,235 +24,240 @@ from __future__ import annotations import logging +from abc import abstractmethod import torch -from dxtb import IndexHelper, labels +from dxtb import labels from dxtb._src.constants import defaults, labels -from dxtb._src.param import Param -from dxtb._src.typing import Any, Tensor +from dxtb._src.typing import Any, Tensor, TensorLike +from dxtb._src.xtb.base import BaseHamiltonian -from .base import IntDriver, IntegralContainer -from .types import Dipole, HCore, Overlap, Quadrupole +from .base import BaseIntegral +from .driver import DriverManager +from .types import DipoleIntegral, OverlapIntegral, QuadrupoleIntegral __all__ = ["Integrals", "IntegralMatrices"] logger = logging.getLogger(__name__) +class IntegralContainer(TensorLike): + """ + Base class for integral container. 
+ """ + + def __init__( + self, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + _run_checks: bool = True, + ): + super().__init__(device, dtype) + self._run_checks = _run_checks + + @property + def run_checks(self) -> bool: + return self._run_checks + + @run_checks.setter + def run_checks(self, run_checks: bool) -> None: + current = self.run_checks + self._run_checks = run_checks + + # switching from False to True should automatically run checks + if current is False and run_checks is True: + self.checks() + + @abstractmethod + def checks(self) -> None: + """Run checks for integrals.""" + + class Integrals(IntegralContainer): """ Integral container. """ __slots__ = [ - "numbers", - "par", - "ihelp", + "mgr", + "intlevel", "_hcore", "_overlap", "_dipole", "_quadrupole", - "_driver", ] def __init__( self, - numbers: Tensor, - par: Param, - ihelp: IndexHelper, + mgr: DriverManager, *, - driver: int = labels.INTDRIVER_LIBCINT, intlevel: int = defaults.INTLEVEL, device: torch.device | None = None, dtype: torch.dtype | None = None, - _hcore: HCore | None = None, - _overlap: Overlap | None = None, - _dipole: Dipole | None = None, - _quadrupole: Quadrupole | None = None, - **kwargs: Any, + _hcore: BaseHamiltonian | None = None, + _overlap: OverlapIntegral | None = None, + _dipole: DipoleIntegral | None = None, + _quadrupole: QuadrupoleIntegral | None = None, ) -> None: super().__init__(device, dtype) - self.numbers = numbers - self.par = par + self.mgr = mgr + self.intlevel = intlevel + self._hcore = _hcore self._overlap = _overlap self._dipole = _dipole self._quadrupole = _quadrupole - self._intlevel = intlevel - - # per default, libcint is run on the CPU - self.force_cpu_for_libcint = kwargs.pop( - "force_cpu_for_libcint", - True if driver == labels.INTDRIVER_LIBCINT else False, - ) - - # Determine which driver class to instantiate - if driver == labels.INTDRIVER_LIBCINT: - # pylint: disable=import-outside-toplevel - from .driver.libcint import IntDriverLibcint - - if self.force_cpu_for_libcint is True: - device = torch.device("cpu") - numbers = numbers.to(device=device) - ihelp = ihelp.to(device=device) - - self._driver = IntDriverLibcint( - numbers, par, ihelp, device=device, dtype=dtype - ) - elif driver == labels.INTDRIVER_ANALYTICAL: - # pylint: disable=import-outside-toplevel - from .driver.pytorch import IntDriverPytorch - - self._driver = IntDriverPytorch( - numbers, par, ihelp, device=device, dtype=dtype - ) - elif driver == labels.INTDRIVER_AUTOGRAD: - # pylint: disable=import-outside-toplevel - from .driver.pytorch import IntDriverPytorchNoAnalytical - - self._driver = IntDriverPytorchNoAnalytical( - numbers, par, ihelp, device=device, dtype=dtype - ) - else: - raise ValueError(f"Unknown integral driver '{driver}'.") - - # potentially moved to CPU - self.ihelp = ihelp - - # Integral driver - - @property - def driver(self) -> IntDriver: - if self._driver is None: - raise ValueError("No integral driver provided.") - return self._driver - - @driver.setter - def driver(self, driver: IntDriver) -> None: - self._driver = driver def setup_driver(self, positions: Tensor, **kwargs: Any) -> None: - logger.debug("Integral Driver: Start setup.") - - if self.force_cpu_for_libcint is True: - positions = positions.to(device=torch.device("cpu")) - - if self.driver.is_latest(positions) is True: - logger.debug("Integral Driver: Skip setup. 
Already done.") - return - - self.driver.setup(positions, **kwargs) - logger.debug("Integral Driver: Finished setup.") - - def invalidate_driver(self) -> None: - """Invalidate the integral driver to require new setup.""" - self.driver.invalidate() + self.mgr.setup_driver(positions, **kwargs) # Core Hamiltonian @property - def hcore(self) -> HCore | None: + def hcore(self) -> BaseHamiltonian | None: return self._hcore @hcore.setter - def hcore(self, hcore: HCore) -> None: + def hcore(self, hcore: BaseHamiltonian) -> None: self._hcore = hcore self.checks() - # TODO: Allow Hamiltonian build without overlap - def build_hcore(self, positions: Tensor, **kwargs) -> Tensor: + def build_hcore( + self, positions: Tensor, with_overlap: bool = True, **kwargs + ) -> Tensor: logger.debug("Core Hamiltonian: Start building matrix.") if self.hcore is None: raise RuntimeError("Core Hamiltonian integral not initialized.") - if self.overlap is None: - raise RuntimeError("Overlap integral not initialized.") - - # overlap integral required - ovlp = self.overlap.integral - if ovlp.matrix is None: + # if overlap integral required + if with_overlap is True and self.overlap is None: self.build_overlap(positions, **kwargs) - cn = kwargs.pop("cn", None) - if cn is None: - # pylint: disable=import-outside-toplevel - from ..ncoord import cn_d3 - - cn = cn_d3(self.numbers, positions) + if with_overlap is True: + assert self.overlap is not None + overlap = self.overlap.matrix + else: + overlap = None - hcore = self.hcore.integral.build(positions, ovlp.matrix, cn=cn) + hcore = self.hcore.build(positions, overlap=overlap) logger.debug("Core Hamiltonian: All finished.") return hcore # overlap @property - def overlap(self) -> Overlap | None: + def overlap(self) -> OverlapIntegral | None: + """ + Overlap integral class. The integral matrix of shape + ``(..., nao, nao)`` is stored in the :attr:`matrix` attribute. + + Returns + ------- + Tensor | None + Overlap integral if set, else ``None``. + """ return self._overlap @overlap.setter - def overlap(self, overlap: Overlap) -> None: + def overlap(self, overlap: OverlapIntegral) -> None: self._overlap = overlap self.checks() def build_overlap(self, positions: Tensor, **kwargs: Any) -> Tensor: # in case CPU is forced for libcint, move positions to CPU - if self.force_cpu_for_libcint is True: - positions = positions.to(device=torch.device("cpu")) + if self.mgr.force_cpu_for_libcint is True: + if positions.device != torch.device("cpu"): + positions = positions.to(device=torch.device("cpu")) - self.setup_driver(positions, **kwargs) + self.mgr.setup_driver(positions, **kwargs) logger.debug("Overlap integral: Start building matrix.") if self.overlap is None: - raise RuntimeError("No overlap integral provided.") + # pylint: disable=import-outside-toplevel + from .factory import new_overlap + + self.overlap = new_overlap( + self.mgr.driver_type, + **self.dd, + **kwargs, + ) - self.overlap.build(self.driver) + # DEVNOTE: If the overlap is only built if `self.overlap._matrix` is + # `None`, the overlap will not be rebuilt if the positions change, + # i.e., when the driver was invalidated. Hence, we would require a + # full reset of the integrals via `reset_all`. However, the integral + # reset cannot be trigger by the driver manager, so we cannot add this + # check here. If we do, the hessian tests will fail as the overlap is + # not recalculated for positions + delta. 
+        self.overlap.build(self.mgr.driver)
+        self.overlap.normalize()
         assert self.overlap.matrix is not None

         # move integral to the correct device...
-        if self.force_cpu_for_libcint is True:
-            # ... but only if no other multipole integrals are required
-            if self._intlevel <= labels.INTLEVEL_HCORE:
-                self.overlap.integral = self.overlap.integral.to(device=self.device)
-
-                # FIXME: The matrix has to be moved explicitly, because when
-                # singlepoint is called a second time, the integral is already
-                # on the correct device (from the to of the first call) and the
-                # matrix is not moved because the to method exits immediately.
-                # This is a workaround and can possibly be fixed when the
-                # matrices are no longer stored (should only return in sp)
-                self.overlap.integral.matrix = self.overlap.integral.matrix.to(
-                    device=self.device
-                )
+        if self.mgr.force_cpu_for_libcint is True:
+            # ...but only if no other multipole integrals are required
+            if self.intlevel <= labels.INTLEVEL_HCORE:
+                self.overlap = self.overlap.to(device=self.device)
+
+                # DEVNOTE: This is a sanity check to avoid the following
+                # scenario: When the overlap is built on CPU (forced by
+                # libcint), it will be moved to the correct device after
+                # the last integral is built. Now, in case of a second
+                # call with an invalid cache, the overlap class already
+                # is on the correct device, but the matrix is not. Hence,
+                # the `to` method must be called on the matrix as well,
+                # which is handled in the custom `to` method of all
+                # integrals.
+                # Also make sure to pass the `force_cpu_for_libcint`
+                # flag when instantiating the integral classes.
+                assert self.overlap is not None
+                if self.overlap.device != self.overlap.matrix.device:
+                    raise RuntimeError(
+                        f"Device of '{self.overlap.label}' integral class "
+                        f"({self.overlap.device}) and its matrix "
+                        f"({self.overlap.matrix.device}) do not match."
+                    )

         logger.debug("Overlap integral: All finished.")

+        return self.overlap.matrix

     def grad_overlap(self, positions: Tensor, **kwargs) -> Tensor:
         # in case CPU is forced for libcint, move positions to CPU
-        if self.force_cpu_for_libcint is True:
-            positions = positions.to(device=torch.device("cpu"))
+        if self.mgr.force_cpu_for_libcint is True:
+            if positions.device != torch.device("cpu"):
+                positions = positions.to(device=torch.device("cpu"))

-        self.setup_driver(positions, **kwargs)
+        self.mgr.setup_driver(positions, **kwargs)

         if self.overlap is None:
-            raise RuntimeError("No overlap integral provided.")
+            # pylint: disable=import-outside-toplevel
+            from .factory import new_overlap
+
+            self.overlap = new_overlap(
+                self.mgr.driver_type,
+                **self.dd,
+                **kwargs,
+            )

         logger.debug("Overlap gradient: Start.")
-        grad = self.overlap.get_gradient(self.driver, **kwargs)
+        self.overlap.get_gradient(self.mgr.driver, **kwargs)
+        self.overlap.gradient = self.overlap.gradient.to(self.device)
+        self.overlap.normalize_gradient()
         logger.debug("Overlap gradient: All finished.")

-        return grad.to(self.device)
+        return self.overlap.gradient.to(self.device)

     # dipole

     @property
-    def dipole(self) -> Dipole | None:
+    def dipole(self) -> DipoleIntegral | None:
         """
-        Dipole integral of shape (3, nao, nao).
+        Dipole integral class. The integral matrix of shape
+        ``(..., 3, nao, nao)`` is stored in the :attr:`matrix` attribute.
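+
+        The operator is shifted from the origin (``r0``) to the ket atoms
+        (``rj``) in :meth:`build_dipole` when ``shift=True``.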
Returns ------- @@ -262,35 +267,40 @@ def dipole(self) -> Dipole | None: return self._dipole @dipole.setter - def dipole(self, dipole: Dipole) -> None: + def dipole(self, dipole: DipoleIntegral) -> None: self._dipole = dipole self.checks() def build_dipole(self, positions: Tensor, shift: bool = True, **kwargs: Any): # in case CPU is forced for libcint, move positions to CPU - if self.force_cpu_for_libcint: - positions = positions.to(device=torch.device("cpu")) + if self.mgr.force_cpu_for_libcint: + if positions.device != torch.device("cpu"): + positions = positions.to(device=torch.device("cpu")) - self.setup_driver(positions, **kwargs) + self.mgr.setup_driver(positions, **kwargs) logger.debug("Dipole integral: Start building matrix.") - if self.overlap is None: - raise RuntimeError("Overlap integral not initialized.") - if self.dipole is None: - raise RuntimeError("Dipole integral not initialized.") + # pylint: disable=import-outside-toplevel + from .factory import new_dipint + + self.dipole = new_dipint(self.mgr.driver_type, **self.dd, **kwargs) + + if self.overlap is None: + self.build_overlap(positions, **kwargs) + assert self.overlap is not None # build (with overlap norm) - self.dipole.integral.norm = self._norm(positions) - self.dipole.build(self.driver) + self.dipole.build(self.mgr.driver) + self.dipole.normalize(self.overlap.norm) logger.debug("Dipole integral: Finished building matrix.") # shift to rj (requires overlap integral) if shift is True: logger.debug("Dipole integral: Start shifting operator (r0->rj).") - self.dipole.integral.shift_r0_rj( - self.overlap.integral.matrix, - self.ihelp.spread_atom_to_orbital( + self.dipole.shift_r0_rj( + self.overlap.matrix, + self.mgr.driver.ihelp.spread_atom_to_orbital( positions, dim=-2, extra=True, @@ -300,26 +310,24 @@ def build_dipole(self, positions: Tensor, shift: bool = True, **kwargs: Any): # move integral to the correct device, but only if no other multipole # integrals are required - if self.force_cpu_for_libcint and self._intlevel <= labels.INTLEVEL_DIPOLE: - self.dipole.integral = self.dipole.integral.to(device=self.device) - self.dipole.integral.matrix = self.dipole.integral.matrix.to( - device=self.device - ) - - self.overlap.integral = self.overlap.integral.to(device=self.device) - self.overlap.integral.matrix = self.overlap.integral.matrix.to( - device=self.device - ) + if ( + self.mgr.force_cpu_for_libcint is True + and self.intlevel <= labels.INTLEVEL_DIPOLE + ): + self.dipole = self.dipole.to(device=self.device) + self.overlap = self.overlap.to(device=self.device) logger.debug("Dipole integral: All finished.") - return self.dipole.integral.matrix + return self.dipole.matrix # quadrupole @property - def quadrupole(self) -> Quadrupole | None: + def quadrupole(self) -> QuadrupoleIntegral | None: """ - Quadrupole integral of shape (6/9, nao, nao). + Quadrupole integral class. The integral matrix of shape + ``(..., 6, nao, nao)`` or ``(..., 9, nao, nao)`` is stored + in the :attr:`matrix` attribute. 
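+
+        The 9-component form is reduced to 6 components when
+        :meth:`build_quadrupole` is called with ``traceless=True``.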
         Returns
         -------
@@ -329,7 +337,7 @@ def quadrupole(self) -> Quadrupole | None:
         return self._quadrupole

     @quadrupole.setter
-    def quadrupole(self, quadrupole: Quadrupole) -> None:
+    def quadrupole(self, quadrupole: QuadrupoleIntegral) -> None:
         self._quadrupole = quadrupole
         self.checks()

@@ -341,44 +349,56 @@ def build_quadrupole(
         **kwargs: Any,
     ):
         # in case CPU is forced for libcint, move positions to CPU
-        if self.force_cpu_for_libcint:
-            positions = positions.to(device=torch.device("cpu"))
+        if self.mgr.force_cpu_for_libcint:
+            if positions.device != torch.device("cpu"):
+                positions = positions.to(device=torch.device("cpu"))

         # check all instantiations
-        self.setup_driver(positions, **kwargs)
+        self.mgr.setup_driver(positions, **kwargs)
         logger.debug("Quad integral: Start building matrix.")

-        if self.overlap is None:
-            raise RuntimeError("Overlap integral not initialized.")
-
-        if self.quadrupole is None:
-            raise RuntimeError("Quadrupole integral not initialized.")
+        # pylint: disable=import-outside-toplevel
+        from .factory import new_quadint
+
+        self.quadrupole = new_quadint(
+            self.mgr.driver_type,
+            **self.dd,
+            **kwargs,
+        )
+
+        if self.overlap is None:
+            self.build_overlap(positions, **kwargs)
+        assert self.overlap is not None

         # build
-        self.quadrupole.integral.norm = self._norm(positions, **kwargs)
-        self.quadrupole.build(self.driver)
+        self.quadrupole.build(self.mgr.driver)
+        self.quadrupole.normalize(self.overlap.norm)
         logger.debug("Quad integral: Finished building matrix.")

         # make traceless before shifting
         if traceless is True:
             logger.debug("Quad integral: Start creating traceless rep.")
-            self.quadrupole.integral.traceless()
+            self.quadrupole.traceless()
             logger.debug("Quad integral: Finished creating traceless rep.")

         # shift to rj (requires overlap and dipole integral)
         if shift is True:
             logger.debug("Quad integral: Start shifting operator (r0r0->rjrj).")
             if traceless is not True:
-                raise RuntimeError("Quadrupole moment must be tracelesss for shifting.")
+                raise RuntimeError(
+                    "Quadrupole moment must be traceless for shifting. "
+                    "Run `quadrupole.traceless()` before shifting."
+                )

             if self.dipole is None:
                 self.build_dipole(positions, **kwargs)
             assert self.dipole is not None

-            self.quadrupole.integral.shift_r0r0_rjrj(
-                self.dipole.integral.matrix,
-                self.overlap.integral.matrix,
-                self.ihelp.spread_atom_to_orbital(
+            self.quadrupole.shift_r0r0_rjrj(
+                self.dipole.matrix,
+                self.overlap.matrix,
+                self.mgr.driver.ihelp.spread_atom_to_orbital(
                     positions,
                     dim=-2,
                     extra=True,
@@ -386,40 +406,20 @@ def build_quadrupole(
             )
             logger.debug("Quad integral: Finished shifting operator.")

-        # move integral to the correct device, but only if no other multipole
-        # integrals are required
-        if self.force_cpu_for_libcint and self._intlevel <= labels.INTLEVEL_QUADRUPOLE:
-            self.overlap.integral = self.overlap.integral.to(self.device)
-            self.overlap.integral.matrix = self.overlap.integral.matrix.to(self.device)
-
-            self.quadrupole.integral = self.quadrupole.integral.to(self.device)
-            self.quadrupole.integral.matrix = self.quadrupole.integral.matrix.to(
-                self.device
-            )
+        # Finally, we move the integral to the correct device, but only if
+        # no other multipole integrals are required.
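+        # (The overlap and, if built, the dipole integral were likewise kept
+        # on the CPU and are moved along with the quadrupole below.)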
+        if (
+            self.mgr.force_cpu_for_libcint is True
+            and self.intlevel <= labels.INTLEVEL_QUADRUPOLE
+        ):
+            self.overlap = self.overlap.to(self.device)
+            self.quadrupole = self.quadrupole.to(self.device)

             if self.dipole is not None:
-                self.dipole.integral = self.dipole.integral.to(self.device)
-                self.dipole.integral.matrix = self.dipole.integral.matrix.to(
-                    self.device
-                )
+                self.dipole = self.dipole.to(self.device)

         logger.debug("Quad integral: All finished.")

-        return self.quadrupole.integral.matrix
-
-    # helper
-
-    def _norm(self, positions: Tensor, **kwargs: Any) -> Tensor:
-        if self.overlap is None:
-            raise RuntimeError("Overlap integral not initialized.")
-
-        # shortcut for overlap integral
-        ovlp = self.overlap.integral
-
-        # overlap integral required for norm and shifting
-        if ovlp.matrix is None or ovlp.norm is None:
-            self.build_overlap(positions, **kwargs)
-
-        return ovlp.norm
+        return self.quadrupole.matrix

     # checks

@@ -428,40 +428,57 @@ def checks(self) -> None:
             return

         for name in ["hcore", "overlap", "dipole", "quadrupole"]:
-            cls = getattr(self, "_" + name)
+            cls: (
+                BaseHamiltonian
+                | OverlapIntegral
+                | DipoleIntegral
+                | QuadrupoleIntegral
+                | None
+            ) = getattr(self, f"_{name}")
+
             if cls is None:
                 continue

-            cls: HCore | Overlap | Dipole | Quadrupole
-
             if cls.dtype != self.dtype:
                 raise RuntimeError(
                     f"Data type of '{cls.label}' integral ({cls.dtype}) and "
-                    f"integral container {self.dtype} do not match."
-                )
-            if cls.device != self.device:
-                raise RuntimeError(
-                    f"Device of '{cls.label}' integral ({cls.device}) and "
-                    f"integral container {self.device} do not match."
+                    f"integral container ({self.dtype}) do not match."
                 )

+            if self.mgr.force_cpu_for_libcint is False:
+                if cls.device != self.device:
+                    raise RuntimeError(
+                        f"Device of '{cls.label}' integral ({cls.device}) and "
+                        f"integral container ({self.device}) do not match."
+                    )

             if name != "hcore":
-                family_integral = cls.integral.family  # type: ignore
-                family_driver = self.driver.family  # type: ignore
+                assert not isinstance(cls, BaseHamiltonian)
+
+                family_integral = cls.family
+                family_driver = self.mgr.driver.family
+                driver_label = self.mgr.driver.label

                 if family_integral != family_driver:
                     raise RuntimeError(
-                        f"The '{cls.integral.label}' integral implementation "
+                        f"The '{cls.label}' integral implementation "
                         f"requests the '{family_integral}' family, but "
-                        f"the integral driver '{self.driver.label}' is "
-                        f"configured with the '{family_driver}' family.\n"
+                        f"the integral driver '{driver_label}' is "
+                        "configured.\n"
                         "If you want to request the 'pytorch' implementations, "
                         "specify the driver name in the constructors of both "
                         "the integral container and the actual integral class."
                     )

     def reset_all(self) -> None:
-        self.invalidate_driver()
-        # TODO: Do we need to reset the specific integrals?
+        self.mgr.invalidate_driver()
+
+        for slot in self.__slots__:
+            i = getattr(self, slot)
+
+            if not slot.startswith("_") or i is None:
+                continue
+
+            if isinstance(i, (BaseIntegral, BaseHamiltonian)):
+                i.clear()

     # pretty print
diff --git a/src/dxtb/_src/integral/driver/__init__.py b/src/dxtb/_src/integral/driver/__init__.py
index 14e7a2d74..dee14ae10 100644
--- a/src/dxtb/_src/integral/driver/__init__.py
+++ b/src/dxtb/_src/integral/driver/__init__.py
@@ -28,4 +28,5 @@
 The `PyTorch` drivers are implemented in pure Python, but are currently
 only available for overlap integrals.
""" -# no imports here to allow lazy loading of drivers +# no imports besides driver manager here to allow lazy loading of drivers +from .manager import DriverManager diff --git a/src/dxtb/_src/integral/driver/factory.py b/src/dxtb/_src/integral/driver/factory.py new file mode 100644 index 000000000..5b4f0e5f8 --- /dev/null +++ b/src/dxtb/_src/integral/driver/factory.py @@ -0,0 +1,125 @@ +# This file is part of dxtb. +# +# SPDX-Identifier: Apache-2.0 +# Copyright (C) 2024 Grimme Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Integral driver: Factories +========================== + +Factory functions for integral drivers. +""" + +from __future__ import annotations + +import torch + +from dxtb import IndexHelper +from dxtb._src.constants import labels +from dxtb._src.param import Param +from dxtb._src.typing import TYPE_CHECKING, Tensor + +from ..base import IntDriver + +if TYPE_CHECKING: + from .libcint import IntDriverLibcint + from .pytorch import ( + IntDriverPytorch, + IntDriverPytorchLegacy, + IntDriverPytorchNoAnalytical, + ) + +__all__ = ["new_driver"] + + +def new_driver( + name: int, + numbers: Tensor, + par: Param, + device: torch.device | None = None, + dtype: torch.dtype | None = None, +) -> IntDriver: + if name == labels.INTDRIVER_LIBCINT: + return new_driver_libcint(numbers, par, device=device, dtype=dtype) + + if name == labels.INTDRIVER_ANALYTICAL: + return new_driver_pytorch(numbers, par, device=device, dtype=dtype) + + if name == labels.INTDRIVER_AUTOGRAD: + return new_driver_pytorch_no_analytical( + numbers, par, device=device, dtype=dtype + ) + + if name == labels.INTDRIVER_LEGACY: + return new_driver_legacy(numbers, par, device=device, dtype=dtype) + + raise ValueError(f"Unknown integral driver '{labels.INTDRIVER_MAP[name]}'.") + + +################################################################################ + + +def new_driver_libcint( + numbers: Tensor, + par: Param, + device: torch.device | None = None, + dtype: torch.dtype | None = None, +) -> IntDriverLibcint: + # pylint: disable=import-outside-toplevel + from .libcint import IntDriverLibcint as _IntDriver + + ihelp = IndexHelper.from_numbers(numbers, par) + return _IntDriver(numbers, par, ihelp, device=device, dtype=dtype) + + +################################################################################ + + +def new_driver_pytorch( + numbers: Tensor, + par: Param, + device: torch.device | None = None, + dtype: torch.dtype | None = None, +) -> IntDriverPytorch: + # pylint: disable=import-outside-toplevel + from .pytorch import IntDriverPytorch as _IntDriver + + ihelp = IndexHelper.from_numbers(numbers, par) + return _IntDriver(numbers, par, ihelp, device=device, dtype=dtype) + + +def new_driver_pytorch_no_analytical( + numbers: Tensor, + par: Param, + device: torch.device | None = None, + dtype: torch.dtype | None = None, +) -> IntDriverPytorchNoAnalytical: + # pylint: disable=import-outside-toplevel + from .pytorch import IntDriverPytorchNoAnalytical as _IntDriver + + ihelp = 
IndexHelper.from_numbers(numbers, par) + return _IntDriver(numbers, par, ihelp, device=device, dtype=dtype) + + +def new_driver_legacy( + numbers: Tensor, + par: Param, + device: torch.device | None = None, + dtype: torch.dtype | None = None, +) -> IntDriverPytorchLegacy: + # pylint: disable=import-outside-toplevel + from .pytorch import IntDriverPytorchLegacy as _IntDriver + + ihelp = IndexHelper.from_numbers(numbers, par) + return _IntDriver(numbers, par, ihelp, device=device, dtype=dtype) diff --git a/src/dxtb/_src/integral/driver/libcint/base.py b/src/dxtb/_src/integral/driver/libcint/base.py new file mode 100644 index 000000000..6b297e7c0 --- /dev/null +++ b/src/dxtb/_src/integral/driver/libcint/base.py @@ -0,0 +1,45 @@ +# This file is part of dxtb. +# +# SPDX-Identifier: Apache-2.0 +# Copyright (C) 2024 Grimme Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Implementation: Base Classes +============================ + +Base class for ``libcint``-based integral implementations. +""" + +from __future__ import annotations + +from dxtb._src.typing import Literal + +from ...base import BaseIntegral + +__all__ = ["IntegralLibcint"] + + +class LibcintImplementation: + """ + Simple label for ``libcint``-based integral implementations. + """ + + family: Literal["libcint"] = "libcint" + """Label for integral implementation family.""" + + +class IntegralLibcint(LibcintImplementation, BaseIntegral): + """ + ``libcint``-based integral implementation. + """ diff --git a/src/dxtb/_src/integral/driver/libcint/base_driver.py b/src/dxtb/_src/integral/driver/libcint/base_driver.py deleted file mode 100644 index a4bcdfb40..000000000 --- a/src/dxtb/_src/integral/driver/libcint/base_driver.py +++ /dev/null @@ -1,128 +0,0 @@ -# This file is part of dxtb. -# -# SPDX-Identifier: Apache-2.0 -# Copyright (C) 2024 Grimme Group -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Driver: Base Class -================== - -Base class for a `libcint`-based integral implementation -Calculation and modification of multipole integrals. 
-""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -from dxtb import IndexHelper -from dxtb._src.basis.bas import Basis -from dxtb._src.constants import labels -from dxtb._src.typing import Tensor -from dxtb._src.utils import is_basis_list - -from ...base import IntDriver - -if TYPE_CHECKING: - from .driver import IntDriverLibcint -del TYPE_CHECKING - - -__all__ = ["BaseIntDriverLibcint"] - - -class LibcintImplementation: - """ - Simple label for `libcint`-based integral implementations. - """ - - family: int = labels.INTDRIVER_LIBCINT - """Label for integral implementation family""" - - def checks(self, driver: IntDriverLibcint) -> None: - """ - Check if the type of integral driver is correct. - - Parameters - ---------- - driver : IntDriverLibcint - Integral driver for the calculation. - """ - # pylint: disable=import-outside-toplevel - from .driver import IntDriverLibcint - - if not isinstance(driver, IntDriverLibcint): - raise RuntimeError("Wrong integral driver selected.") - - -class BaseIntDriverLibcint(LibcintImplementation, IntDriver): - """ - Implementation of `libcint`-based integral driver. - """ - - family: int = labels.INTDRIVER_LIBCINT - """Label for integral implementation family""" - - def setup(self, positions: Tensor, **kwargs) -> None: - """ - Run the `libcint`-specific driver setup. - - Parameters - ---------- - positions : Tensor - Cartesian coordinates of all atoms (shape: ``(..., nat, 3)``). - """ - # pylint: disable=import-outside-toplevel - from dxtb._src.exlibs import libcint - - # setup `Basis` class if not already done - if self._basis is None: - self.basis = Basis(self.numbers, self.par, self.ihelp, **self.dd) - - # create atomic basis set in libcint format - mask = kwargs.pop("mask", None) - atombases = self.basis.create_libcint(positions, mask=mask) - - if self.ihelp.batch_mode > 0: - - # integrals do not work with a batched IndexHelper - if self.ihelp.batch_mode == 1: - # pylint: disable=import-outside-toplevel - from tad_mctc.batch import deflate - - _ihelp = [ - IndexHelper.from_numbers(deflate(number), self.par) - for number in self.numbers - ] - elif self.ihelp.batch_mode == 2: - _ihelp = [ - IndexHelper.from_numbers(number, self.par) - for number in self.numbers - ] - else: - raise ValueError(f"Unknown batch mode '{self.ihelp.batch_mode}'.") - - assert isinstance(atombases, list) - self.drv = [ - libcint.LibcintWrapper(ab, ihelp) - for ab, ihelp in zip(atombases, _ihelp) - if is_basis_list(ab) - ] - else: - assert is_basis_list(atombases) - self.drv = libcint.LibcintWrapper(atombases, self.ihelp) - - # setting positions signals successful setup; save current positions to - # catch new positions and run the required re-setup of the driver - self._positions = positions.detach().clone() diff --git a/src/dxtb/_src/integral/driver/libcint/base_implementation.py b/src/dxtb/_src/integral/driver/libcint/base_implementation.py deleted file mode 100644 index 3dd3300d2..000000000 --- a/src/dxtb/_src/integral/driver/libcint/base_implementation.py +++ /dev/null @@ -1,161 +0,0 @@ -# This file is part of dxtb. -# -# SPDX-Identifier: Apache-2.0 -# Copyright (C) 2024 Grimme Group -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Libcint: Base Implementation -============================ - -Base class for `libcint`-based integral implementations. -""" - -from __future__ import annotations - -from abc import abstractmethod -from typing import TYPE_CHECKING - -import torch - -from dxtb._src.constants import labels -from dxtb._src.typing import Self, Tensor, override - -from ...base import BaseIntegralImplementation - -if TYPE_CHECKING: - from .driver import IntDriverLibcint -del TYPE_CHECKING - - -__all__ = ["IntegralImplementationLibcint", "LibcintImplementation"] - - -class LibcintImplementation: - """ - Simple label for `libcint`-based integral implementations. - """ - - family: int = labels.INTDRIVER_LIBCINT - """Label for integral implementation family""" - - def checks(self, driver: IntDriverLibcint) -> None: - """ - Check if the type of integral driver is correct. - - Parameters - ---------- - driver : IntDriverLibcint - Integral driver for the calculation. - """ - # pylint: disable=import-outside-toplevel - from .driver import IntDriverLibcint - - if not isinstance(driver, IntDriverLibcint): - raise RuntimeError("Wrong integral driver selected.") - - -class IntegralImplementationLibcint( - LibcintImplementation, - BaseIntegralImplementation, -): - """PyTorch-based integral implementation""" - - def checks(self, driver: IntDriverLibcint) -> None: - """ - Check if the type of integral driver is correct. - - Parameters - ---------- - driver : BaseIntDriverPytorch - Integral driver for the calculation. - """ - super().checks(driver) - - # pylint: disable=import-outside-toplevel - from .driver import IntDriverLibcint - - if not isinstance(driver, IntDriverLibcint): - raise RuntimeError("Wrong integral driver selected.") - - def get_gradient(self, _: IntDriverLibcint) -> Tensor: - """ - Create the nuclear integral derivative matrix. - - Parameters - ---------- - driver : IntDriver - Integral driver for the calculation. - - Returns - ------- - Tensor - Nuclear integral derivative matrix. - """ - raise NotImplementedError( - "The `get_gradient` method is not implemented for libcint " - "integrals as it is not explicitly required." - ) - - @abstractmethod - def build(self, driver: IntDriverLibcint) -> Tensor: - """ - Calculation of the integral using libcint. - - Returns - ------- - driver : IntDriverLibcint - The integral driver for the calculation. - """ - - @override - def to(self, device: torch.device) -> Self: - """ - Returns a copy of the :class:`.IntegralImplementationLibcint` instance - on the specified device. - - This method overwrites the usual approach because the - :class:`.IntegralImplementationLibcint` class should not change the - device of the norm . - - Parameters - ---------- - device : torch.device - Device to which all associated tensors should be moved. - - Returns - ------- - BaseIntegral - A copy of the :class:`.IntegralImplementationLibcint` instance - placed on the specified device. - - Raises - ------ - RuntimeError - If the ``__slots__`` attribute is not set in the class. 
- """ - if self.device == device: - return self - - if len(self.__slots__) == 0: - raise RuntimeError( - f"The `to` method requires setting ``__slots__`` in the " - f"'{self.__class__.__name__}' class." - ) - - self.matrix = self.matrix.to(device) - if self._gradient is not None: - self.gradient = self.gradient.to(device) - - self.override_device(device) - return self diff --git a/src/dxtb/_src/integral/driver/libcint/dipole.py b/src/dxtb/_src/integral/driver/libcint/dipole.py index 3c9eaffbb..5ca7903aa 100644 --- a/src/dxtb/_src/integral/driver/libcint/dipole.py +++ b/src/dxtb/_src/integral/driver/libcint/dipole.py @@ -23,17 +23,16 @@ from __future__ import annotations -from tad_mctc.math import einsum - from dxtb._src.typing import Tensor +from ...types import DipoleIntegral from .driver import IntDriverLibcint from .multipole import MultipoleLibcint __all__ = ["DipoleLibcint"] -class DipoleLibcint(MultipoleLibcint): +class DipoleLibcint(DipoleIntegral, MultipoleLibcint): """ Dipole integral from atomic orbitals. """ @@ -54,48 +53,18 @@ def build(self, driver: IntDriverLibcint) -> Tensor: """ return self.multipole(driver, "r0") - def shift_r0_rj(self, overlap: Tensor, pos: Tensor) -> Tensor: - r""" - Shift the centering of the dipole integral (moment operator) from the - origin (:math:`r0 = r - (0, 0, 0)`) to atoms (ket index, - :math:`rj = r - r_j`). - - .. math:: - - \begin{align} - D &= D^{r_j} \\ - &= \langle i | r_j | j \rangle \\ - &= \langle i | r | j \rangle - r_j \langle i | j \rangle \\ - &= \langle i | r_0 | j \rangle - r_j S_{ij} \\ - &= D^{r_0} - r_j S_{ij} - \end{align} + def get_gradient(self, driver: IntDriverLibcint) -> Tensor: + """ + Calculation of dipole gradient using libcint. Parameters ---------- - r0 : Tensor - Origin centered dipole integral. - overlap : Tensor - Overlap integral. - pos : Tensor - Orbital-resolved atomic positions. - - Raises - ------ - RuntimeError - Shape mismatch between ``positions`` and `overlap`. - The positions must be orbital-resolved. + driver : IntDriverLibcint + The integral driver for the calculation. Returns ------- Tensor - Second-index (ket) atom-centered dipole integral. + Dipole gradient. """ - if pos.shape[-2] != overlap.shape[-1]: - raise RuntimeError( - "Shape mismatch between positions and overlap integral. " - "The position tensor must be spread to orbital-resolution." - ) - - shift = einsum("...jx,...ij->...xij", pos, overlap) - self.matrix = self.matrix - shift - return self.matrix + raise NotImplementedError("Gradient calculation not implemented.") diff --git a/src/dxtb/_src/integral/driver/libcint/driver.py b/src/dxtb/_src/integral/driver/libcint/driver.py index 84e99853d..844352e35 100644 --- a/src/dxtb/_src/integral/driver/libcint/driver.py +++ b/src/dxtb/_src/integral/driver/libcint/driver.py @@ -18,17 +18,83 @@ Driver: Libcint =============== -Integral driver for `libcint`. +Base class for a `libcint`-based integral implementation +Calculation and modification of multipole integrals. 
""" from __future__ import annotations -from .base_driver import BaseIntDriverLibcint +from dxtb import IndexHelper +from dxtb._src.basis.bas import Basis +from dxtb._src.typing import Tensor +from dxtb._src.utils import is_basis_list -__all__ = ["IntDriverLibcint"] +from ...base import IntDriver +from .base import LibcintImplementation +__all__ = ["BaseIntDriverLibcint", "IntDriverLibcint"] -class IntDriverLibcint(BaseIntDriverLibcint): + +class BaseIntDriverLibcint(LibcintImplementation, IntDriver): """ Implementation of `libcint`-based integral driver. """ + + def setup(self, positions: Tensor, **kwargs) -> None: + """ + Run the `libcint`-specific driver setup. + + Parameters + ---------- + positions : Tensor + Cartesian coordinates of all atoms (shape: ``(..., nat, 3)``). + """ + # pylint: disable=import-outside-toplevel + from dxtb._src.exlibs import libcint + + # setup `Basis` class if not already done + if self._basis is None: + self.basis = Basis(self.numbers, self.par, self.ihelp, **self.dd) + + # create atomic basis set in libcint format + mask = kwargs.pop("mask", None) + atombases = self.basis.create_libcint(positions, mask=mask) + + if self.ihelp.batch_mode > 0: + + # integrals do not work with a batched IndexHelper + if self.ihelp.batch_mode == 1: + # pylint: disable=import-outside-toplevel + from tad_mctc.batch import deflate + + _ihelp = [ + IndexHelper.from_numbers(deflate(number), self.par) + for number in self.numbers + ] + elif self.ihelp.batch_mode == 2: + _ihelp = [ + IndexHelper.from_numbers(number, self.par) + for number in self.numbers + ] + else: + raise ValueError(f"Unknown batch mode '{self.ihelp.batch_mode}'.") + + assert isinstance(atombases, list) + self.drv = [ + libcint.LibcintWrapper(ab, ihelp) + for ab, ihelp in zip(atombases, _ihelp) + if is_basis_list(ab) + ] + else: + assert is_basis_list(atombases) + self.drv = libcint.LibcintWrapper(atombases, self.ihelp) + + # setting positions signals successful setup; save current positions to + # catch new positions and run the required re-setup of the driver + self._positions = positions.detach().clone() + + +class IntDriverLibcint(BaseIntDriverLibcint): + """ + Implementation of ``libcint``-based integral driver. + """ diff --git a/src/dxtb/_src/integral/driver/libcint/multipole.py b/src/dxtb/_src/integral/driver/libcint/multipole.py index cdee7424f..afbb1da6c 100644 --- a/src/dxtb/_src/integral/driver/libcint/multipole.py +++ b/src/dxtb/_src/integral/driver/libcint/multipole.py @@ -26,21 +26,19 @@ from typing import TYPE_CHECKING from tad_mctc.batch import pack -from tad_mctc.math import einsum from dxtb._src.exlibs import libcint from dxtb._src.typing import Tensor -from .base_implementation import IntegralImplementationLibcint +from .base import IntegralLibcint if TYPE_CHECKING: from .driver import IntDriverLibcint -del TYPE_CHECKING __all__ = ["MultipoleLibcint"] -class MultipoleLibcint(IntegralImplementationLibcint): +class MultipoleLibcint(IntegralLibcint): """ Base class for multipole integrals calculated with `libcint`. """ @@ -72,14 +70,8 @@ def multipole(self, driver: IntDriverLibcint, intstring: str) -> Tensor: "Other integrals can be added to `tad-libcint`." 
) - if self.norm is None: - raise RuntimeError("Norm must be set before building.") - - def _mpint(driver: libcint.LibcintWrapper, norm: Tensor) -> Tensor: - integral = libcint.int1e(intstring, driver) - if self.normalize is False: - return integral - return einsum("...ij,i,j->...ij", integral, norm, norm) + def _mpint(driver: libcint.LibcintWrapper) -> Tensor: + return libcint.int1e(intstring, driver) # batched mode if driver.ihelp.batch_mode > 0: @@ -88,27 +80,11 @@ def _mpint(driver: libcint.LibcintWrapper, norm: Tensor) -> Tensor: "IndexHelper on integral driver is batched, but the driver " "instance itself not." ) - if driver.ihelp.batch_mode == 1: - # pylint: disable=import-outside-toplevel - from tad_mctc.batch import deflate - - self.matrix = pack( - [ - _mpint(driver, deflate(norm)) - for driver, norm in zip(driver.drv, self.norm) - ] - ) - return self.matrix - elif driver.ihelp.batch_mode == 2: - self.matrix = pack( - [ - _mpint(driver, norm) # no deflating here - for driver, norm in zip(driver.drv, self.norm) - ] - ) - return self.matrix - raise ValueError(f"Unknown batch mode '{driver.ihelp.batch_mode}'.") + # In this version, batch mode does not matter. If we would + # normalize the integral here, we would have to deflate the norm. + self.matrix = pack([_mpint(driver) for driver in driver.drv]) + return self.matrix # single mode if not isinstance(driver.drv, libcint.LibcintWrapper): @@ -117,5 +93,5 @@ def _mpint(driver: libcint.LibcintWrapper, norm: Tensor) -> Tensor: "driver instance itself seems to be batched." ) - self.matrix = _mpint(driver.drv, self.norm) + self.matrix = _mpint(driver.drv) return self.matrix diff --git a/src/dxtb/_src/integral/driver/libcint/overlap.py b/src/dxtb/_src/integral/driver/libcint/overlap.py index 2c5f68dfc..28f8ec000 100644 --- a/src/dxtb/_src/integral/driver/libcint/overlap.py +++ b/src/dxtb/_src/integral/driver/libcint/overlap.py @@ -30,19 +30,23 @@ from dxtb._src.exlibs import libcint from dxtb._src.typing import Tensor -from .base_implementation import IntegralImplementationLibcint +from ...types import OverlapIntegral +from ...utils import snorm +from .base import IntegralLibcint from .driver import IntDriverLibcint __all__ = ["OverlapLibcint"] -def snorm(ovlp: Tensor) -> Tensor: - return torch.pow(ovlp.diagonal(dim1=-1, dim2=-2), -0.5) - - -class OverlapLibcint(IntegralImplementationLibcint): +class OverlapLibcint(OverlapIntegral, IntegralLibcint): """ Overlap integral from atomic orbitals. + + Use the :meth:`build` method to calculate the overlap integral. The + returned matrix uses a custom autograd function to calculate the + backward pass with the analytical gradient. + For the full gradient, i.e., a matrix of shape ``(..., norb, norb, 3)``, + the :meth:`get_gradient` method should be used. """ def build(self, driver: IntDriverLibcint) -> Tensor: @@ -53,29 +57,20 @@ def build(self, driver: IntDriverLibcint) -> Tensor: ------- driver : IntDriverLibcint The integral driver for the calculation. + + Returns + ------- + Tensor + Overlap integral matrix of shape ``(..., norb, norb)``. 
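+
+        Note
+        ----
+        The diagonal-based normalization factors (:func:`snorm`) are computed
+        alongside the matrix and stored in :attr:`norm`.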
""" super().checks(driver) - def fcn(driver: libcint.LibcintWrapper) -> tuple[Tensor, Tensor]: - s = libcint.overlap(driver) - norm = snorm(s) - - if self.normalize is True: - s = einsum("...ij,...i,...j->...ij", s, norm, norm) - - return s, norm - # batched mode if driver.ihelp.batch_mode > 0: assert isinstance(driver.drv, list) - slist = [] - nlist = [] - - for d in driver.drv: - mat, norm = fcn(d) - slist.append(mat) - nlist.append(norm) + slist = [libcint.overlap(d) for d in driver.drv] + nlist = [snorm(s) for s in slist] self.norm = pack(nlist) self.matrix = pack(slist) @@ -84,7 +79,8 @@ def fcn(driver: libcint.LibcintWrapper) -> tuple[Tensor, Tensor]: # single mode assert isinstance(driver.drv, libcint.LibcintWrapper) - self.matrix, self.norm = fcn(driver.drv) + self.matrix = libcint.overlap(driver.drv) + self.norm = snorm(self.matrix) return self.matrix def get_gradient(self, driver: IntDriverLibcint) -> Tensor: @@ -99,29 +95,21 @@ def get_gradient(self, driver: IntDriverLibcint) -> Tensor: Returns ------- Tensor - Overlap gradient of shape `(nb, norb, norb, 3)`. + Overlap gradient of shape ``(..., norb, norb, 3)``. """ super().checks(driver) - def fcn(driver: libcint.LibcintWrapper, norm: Tensor) -> Tensor: + # build norm if not already available + if self.norm is None: + self.build(driver) + + def fcn(driver: libcint.LibcintWrapper) -> Tensor: # (3, norb, norb) grad = libcint.int1e("ipovlp", driver) - if self.normalize is False: - return -einsum("...xij->...ijx", grad) - - # normalize and move xyz dimension to last, which is required for - # the reduction (only works with extra dimension in last) - return -einsum("...xij,...i,...j->...ijx", grad, norm, norm) - - # build norm if not already available - if self.norm is None: - if driver.ihelp.batch_mode > 0: - assert isinstance(driver.drv, list) - self.norm = pack([snorm(libcint.overlap(d)) for d in driver.drv]) - else: - assert isinstance(driver.drv, libcint.LibcintWrapper) - self.norm = snorm(libcint.overlap(driver.drv)) + # Move xyz dimension to last, which is required for the + # reduction (only works with extra dimension in last) + return -einsum("...xij->...ijx", grad) # batched mode if driver.ihelp.batch_mode > 0: @@ -132,24 +120,12 @@ def fcn(driver: libcint.LibcintWrapper, norm: Tensor) -> Tensor: ) if driver.ihelp.batch_mode == 1: - # pylint: disable=import-outside-toplevel - from tad_mctc.batch import deflate - - self.grad = pack( - [ - fcn(driver, deflate(norm)) - for driver, norm in zip(driver.drv, self.norm) - ] - ) - return self.grad + self.gradient = pack([fcn(d) for d in driver.drv]) + return self.gradient + elif driver.ihelp.batch_mode == 2: - self.grad = pack( - [ - fcn(driver, norm) # no deflating here - for driver, norm in zip(driver.drv, self.norm) - ] - ) - return self.grad + self.gradient = torch.stack([fcn(d) for d in driver.drv]) + return self.gradient raise ValueError(f"Unknown batch mode '{driver.ihelp.batch_mode}'.") @@ -160,5 +136,6 @@ def fcn(driver: libcint.LibcintWrapper, norm: Tensor) -> Tensor: "driver instance itself seems to be batched." 
         )

-        self.grad = fcn(driver.drv, self.norm)
-        return self.grad
+        self.gradient = fcn(driver.drv)
+        return self.gradient
diff --git a/src/dxtb/_src/integral/driver/libcint/quadrupole.py b/src/dxtb/_src/integral/driver/libcint/quadrupole.py
index 878da2ee8..ad5a700e6 100644
--- a/src/dxtb/_src/integral/driver/libcint/quadrupole.py
+++ b/src/dxtb/_src/integral/driver/libcint/quadrupole.py
@@ -23,18 +23,17 @@

 from __future__ import annotations

-import torch
-from tad_mctc.math import einsum
-
-from dxtb._src.typing import Tensor
+from dxtb._src.integral.base import IntDriver
+from dxtb._src.typing import Any, Tensor

+from ...types import QuadrupoleIntegral
 from .driver import IntDriverLibcint
 from .multipole import MultipoleLibcint

 __all__ = ["QuadrupoleLibcint"]


-class QuadrupoleLibcint(MultipoleLibcint):
+class QuadrupoleLibcint(QuadrupoleIntegral, MultipoleLibcint):
     """
     Quadrupole integral from atomic orbitals.
     """
@@ -55,238 +54,5 @@ def build(self, driver: IntDriverLibcint) -> Tensor:
         """
         return self.multipole(driver, "r0r0")

-    def traceless(self) -> Tensor:
-        """
-        Make a quadrupole (integral) tensor traceless.
-
-        Parameters
-        ----------
-        qpint : Tensor
-            Quadrupole moment tensor of shape ``(..., 9, n, n)``.
-
-        Returns
-        -------
-        Tensor
-            Traceless Quadrupole moment tensor of shape
-            ``(..., 6, n, n)``.
-
-        Raises
-        ------
-        RuntimeError
-            Supplied quadrupole integral is no ``3x3`` tensor.
-
-        Note
-        ----
-        First the quadrupole tensor is reshaped to be symmetric.
-        Due to symmetry, only the lower triangular matrix is used.
-
-        xx xy xz       0 1 2      0
-        yx yy yz  <=>  3 4 5  ->  3 4
-        zx zy zz       6 7 8      6 7 8
-        """
-
-        if self.matrix.shape[-3] != 9:
-            raise RuntimeError(
-                "Quadrupole integral must be a tensor tensor of shape "
-                f"'(9, nao, nao)' but is {self.matrix.shape}."
-            )
-
-        # (..., 9, norb, norb) -> (..., 3, 3, norb, norb)
-        shp = self.matrix.shape
-        qpint = self.matrix.view(*shp[:-3], 3, 3, *shp[-2:])
-
-        # trace: (..., 3, 3, norb, norb) -> (..., norb, norb)
-        tr = 0.5 * einsum("...iijk->...jk", qpint)
-
-        self.matrix = torch.stack(
-            [
-                1.5 * qpint[..., 0, 0, :, :] - tr,  # xx
-                1.5 * qpint[..., 1, 0, :, :],  # yx
-                1.5 * qpint[..., 1, 1, :, :] - tr,  # yy
-                1.5 * qpint[..., 2, 0, :, :],  # zx
-                1.5 * qpint[..., 2, 1, :, :],  # zy
-                1.5 * qpint[..., 2, 2, :, :] - tr,  # zz
-            ],
-            dim=-3,
-        )
-        return self.matrix
-
-    def shift_r0r0_rjrj(self, r0: Tensor, overlap: Tensor, pos: Tensor) -> Tensor:
-        r"""
-        Shift the centering of the quadrupole integral (moment operator) from
-        the origin (:math:`r0 = r - (0, 0, 0)`) to atoms (ket index,
-        :math:`rj = r - r_j`).
-
-        Create the shift contribution for all diagonal elements of the
-        quadrupole integral.
-
-        We start with the quadrupole integral generated by the ``r0`` moment
-        operator:
-
-        .. math::
-
-            Q_{xx}^{r0} = \langle i | (r_x - r0)^2 | j \rangle = \langle i | r_x^2 | j \rangle
-
-        Now, we shift the integral to ``r_j`` yielding the quadrupole integral
-        center on the respective atoms:
-
-        .. math::
-
-            \begin{align}
-                Q_{xx} &= \langle i | (r_x - r_{xj})^2 | j \rangle \\
-                &= \langle i | r_x^2 | j \rangle - 2 \langle i | r_{xj} r_x | j \rangle + \langle i | r_{xj}^2 | j \rangle \\
-                &= Q_{xx}^{r0} - 2 r_{xj} \langle i | r_x | j \rangle + r_{xj}^2 \langle i | j \rangle \\
-                &= Q_{xx}^{r0} - 2 r_{xj} D_{x}^{r0} + r_{xj}^2 S_{ij}
-            \end{align}
-
-        Next, we create the shift contribution for all off-diagonal elements of
-        the quadrupole integral.
-
-        ..
math:: - - \begin{align} - Q_{ab} &= \langle i | (r_a - r_{aj})(r_b - r_{bj}) | j \rangle \\ - &= \langle i | r_a r_b | j \rangle - \langle i | r_a r_{bj} | j \rangle - \langle i | r_{aj} r_b | j \rangle + \langle i | r_{aj} r_{bj} | j \rangle \\ - &= Q_{ab}^{r0} - r_{bj} \langle i | r_a | j \rangle - r_{aj} \langle i | r_b | j \rangle + r_{aj} r_{bj} \langle i | j \rangle \\ - &= Q_{ab}^{r0} - r_{bj} D_a^{r0} - r_{aj} D_b^{r0} + r_{aj} r_{bj} S_{ij} - \end{align} - - Parameters - ---------- - r0 : Tensor - Origin-centered dipole integral. - overlap : Tensor - Monopole integral (overlap). - pos : Tensor - Orbital-resolved atomic positions. - - Raises - ------ - RuntimeError - Shape mismatch between ``positions`` and ``overlap``. - The positions must be orbital-resolved. - - Returns - ------- - Tensor - Second-index (ket) atom-centered quadrupole integral. - """ - if pos.shape[-2] != overlap.shape[-1]: - raise RuntimeError( - "Shape mismatch between positions and overlap integral. " - "The position tensor must be spread to orbital-resolution." - ) - - # cartesian components for convenience - x = pos[..., 0] - y = pos[..., 1] - z = pos[..., 2] - dpx = r0[..., 0, :, :] - dpy = r0[..., 1, :, :] - dpz = r0[..., 2, :, :] - - # construct shift contribution from dipole and monopole (overlap) moments - shift_xx = shift_diagonal(x, dpx, overlap) - shift_yy = shift_diagonal(y, dpy, overlap) - shift_zz = shift_diagonal(z, dpz, overlap) - shift_yx = shift_offdiag(y, x, dpy, dpx, overlap) - shift_zx = shift_offdiag(z, x, dpz, dpx, overlap) - shift_zy = shift_offdiag(z, y, dpz, dpy, overlap) - - # collect the trace of shift contribution - tr = 0.5 * (shift_xx + shift_yy + shift_zz) - - self.matrix = torch.stack( - [ - self.matrix[..., 0, :, :] + 1.5 * shift_xx - tr, # xx - self.matrix[..., 1, :, :] + 1.5 * shift_yx, # yx - self.matrix[..., 2, :, :] + 1.5 * shift_yy - tr, # yy - self.matrix[..., 3, :, :] + 1.5 * shift_zx, # zx - self.matrix[..., 4, :, :] + 1.5 * shift_zy, # zy - self.matrix[..., 5, :, :] + 1.5 * shift_zz - tr, # zz - ], - dim=-3, - ) - return self.matrix - - -def shift_diagonal(c: Tensor, dpc: Tensor, s: Tensor) -> Tensor: - r""" - Create the shift contribution for all diagonal elements of the quadrupole - integral. - - We start with the quadrupole integral generated by the ``r0`` moment - operator: - - .. math:: - - Q_{xx}^{r0} = \langle i | (r_x - r0)^2 | j \rangle = \langle i | r_x^2 | j \rangle - - Now, we shift the integral to ``r_j`` yielding the quadrupole integral - center on the respective atoms: - - .. math:: - - \begin{align} - Q_{xx} &= \langle i | (r_x - r_{xj})^2 | j \rangle \\ - &= \langle i | r_x^2 | j \rangle - 2 \langle i | r_{xj} r_x | j \rangle + \langle i | r_{xj}^2 | j \rangle \\ - &= Q_{xx}^{r0} - 2 r_{xj} \langle i | r_x | j \rangle + r_{xj}^2 \langle i | j \rangle \\ - &= Q_{xx}^{r0} - 2 r_{xj} D_{x}^{r0} + r_{xj}^2 S_{ij} - \end{align} - - Parameters - ---------- - c : Tensor - Cartesian component. - dpc : Tensor - Cartesian component of dipole integral (`r0` operator). - s : Tensor - Overlap integral. - - Returns - ------- - Tensor - Shift contribution for diagonals of quadrupole integral. - """ - shift_1 = -2 * einsum("...j,...ij->...ij", c, dpc) - shift_2 = einsum("...j,...j,...ij->...ij", c, c, s) - return shift_1 + shift_2 - - -def shift_offdiag(a: Tensor, b: Tensor, dpa: Tensor, dpb: Tensor, s: Tensor) -> Tensor: - r""" - Create the shift contribution for all off-diagonal elements of the - quadrupole integral. - - .. 
math::
-
-        \begin{align}
-            Q_{ab} &= \langle i | (r_a - r_{aj})(r_b - r_{bj}) | j \rangle \\
-            &= \langle i | r_a r_b | j \rangle - \langle i | r_a r_{bj} | j \rangle - \langle i | r_{aj} r_b | j \rangle + \langle i | r_{aj} r_{bj} | j \rangle \\
-            &= Q_{ab}^{r0} - r_{bj} \langle i | r_a | j \rangle - r_{aj} \langle i | r_b | j \rangle + r_{aj} r_{bj} \langle i | j \rangle \\
-            &= Q_{ab}^{r0} - r_{bj} D_a^{r0} - r_{aj} D_b^{r0} + r_{aj} r_{bj} S_{ij}
-        \end{align}
-
-    Parameters
-    ----------
-    a : Tensor
-        First cartesian component.
-    b : Tensor
-        Second cartesian component.
-    dpa : Tensor
-        First cartesian component of dipole integral (r0 operator).
-    dpb : Tensor
-        Second cartesian component of dipole integral (r0 operator).
-    s : Tensor
-        Overlap integral.
-
-    Returns
-    -------
-    Tensor
-        Shift contribution of off-diagonal elements of quadrupole integral.
-    """
-    shift_ab_1 = -einsum("...j,...ij->...ij", b, dpa)
-    shift_ab_2 = -einsum("...j,...ij->...ij", a, dpb)
-    shift_ab_3 = einsum("...j,...j,...ij->...ij", a, b, s)
-
-    return shift_ab_1 + shift_ab_2 + shift_ab_3
+    def get_gradient(self, driver: IntDriver, **kwargs: Any) -> Tensor:
+        raise NotImplementedError("Gradient calculation not implemented.")
diff --git a/src/dxtb/_src/integral/driver/manager.py b/src/dxtb/_src/integral/driver/manager.py
new file mode 100644
index 000000000..d7709a74a
--- /dev/null
+++ b/src/dxtb/_src/integral/driver/manager.py
@@ -0,0 +1,132 @@
+# This file is part of dxtb.
+#
+# SPDX-Identifier: Apache-2.0
+# Copyright (C) 2024 Grimme Group
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Integrals: Driver Manager
+=========================
+
+The driver manager contains the selection logic, i.e., it instantiates
+the appropriate driver based on the configuration.
+"""
+from __future__ import annotations
+
+import logging
+
+import torch
+
+from dxtb import IndexHelper, labels
+from dxtb._src.param import Param
+from dxtb._src.typing import TYPE_CHECKING, Any, Tensor, TensorLike
+
+if TYPE_CHECKING:
+    from ..base import IntDriver
+
+
+__all__ = ["DriverManager"]
+
+
+logger = logging.getLogger(__name__)
+
+
+class DriverManager(TensorLike):
+    """
+    This class instantiates the appropriate driver based on the
+    configuration passed to it.
+    """
+
+    __slots__ = ["_driver", "driver_type", "force_cpu_for_libcint"]
+
+    def __init__(
+        self,
+        driver_type: int,
+        _driver: IntDriver | None = None,
+        device: torch.device | None = None,
+        dtype: torch.dtype | None = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(device=device, dtype=dtype)
+
+        # by default, libcint runs on the CPU
+        self.force_cpu_for_libcint = kwargs.pop(
+            "force_cpu_for_libcint",
+            driver_type == labels.INTDRIVER_LIBCINT,
+        )
+
+        self.driver_type = driver_type
+        self._driver = _driver
+
+    @property
+    def driver(self) -> IntDriver:
+        if self._driver is None:
+            raise RuntimeError(
+                "No driver has been created yet. Run `create_driver` first."
+ ) + + return self._driver + + @driver.setter + def driver(self, driver: IntDriver) -> None: + self._driver = driver + + def create_driver(self, numbers: Tensor, par: Param, ihelp: IndexHelper) -> None: + if self.driver_type == labels.INTDRIVER_LIBCINT: + # pylint: disable=import-outside-toplevel + from .libcint import IntDriverLibcint as _IntDriver + + if self.force_cpu_for_libcint is True: + device = torch.device("cpu") + numbers = numbers.to(device=device) + ihelp = ihelp.to(device=device) + + elif self.driver_type == labels.INTDRIVER_ANALYTICAL: + # pylint: disable=import-outside-toplevel + from .pytorch import IntDriverPytorch as _IntDriver + + elif self.driver_type == labels.INTDRIVER_AUTOGRAD: + # pylint: disable=import-outside-toplevel + from .pytorch import IntDriverPytorchNoAnalytical as _IntDriver + + else: + raise ValueError(f"Unknown integral driver '{self.driver_type}'.") + + self.driver = _IntDriver( + numbers, par, ihelp, device=ihelp.device, dtype=self.dtype + ) + + def setup_driver(self, positions: Tensor, **kwargs: Any) -> None: + """ + Setup the integral driver (if not already done). + + Parameters + ---------- + positions : Tensor + Cartesian coordinates of all atoms (shape: ``(..., nat, 3)``). + """ + logger.debug("Integral Driver: Start setup.") + + if self.force_cpu_for_libcint is True: + positions = positions.to(device=torch.device("cpu")) + + if self.driver.is_latest(positions) is True: + logger.debug("Integral Driver: Skip setup. Already done.") + return + + self.driver.setup(positions, **kwargs) + logger.debug("Integral Driver: Finished setup.") + + def invalidate_driver(self) -> None: + """Invalidate the integral driver to require new setup.""" + self.driver.invalidate() diff --git a/src/dxtb/_src/integral/driver/pytorch/__init__.py b/src/dxtb/_src/integral/driver/pytorch/__init__.py index 34de67904..0edfd7508 100644 --- a/src/dxtb/_src/integral/driver/pytorch/__init__.py +++ b/src/dxtb/_src/integral/driver/pytorch/__init__.py @@ -21,15 +21,19 @@ Pytorch-based integral implementations. """ +from .dipole import DipolePytorch from .driver import ( IntDriverPytorch, IntDriverPytorchLegacy, IntDriverPytorchNoAnalytical, ) from .overlap import OverlapPytorch +from .quadrupole import QuadrupolePytorch __all__ = [ "OverlapPytorch", + "DipolePytorch", + "QuadrupolePytorch", "IntDriverPytorch", "IntDriverPytorchLegacy", "IntDriverPytorchNoAnalytical", diff --git a/src/dxtb/_src/integral/driver/pytorch/base.py b/src/dxtb/_src/integral/driver/pytorch/base.py new file mode 100644 index 000000000..d3917c8fc --- /dev/null +++ b/src/dxtb/_src/integral/driver/pytorch/base.py @@ -0,0 +1,45 @@ +# This file is part of dxtb. +# +# SPDX-Identifier: Apache-2.0 +# Copyright (C) 2024 Grimme Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Implementation: Base Classes +============================ + +Base class for ``PyTorch``-based drivers and integral implementations. 
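+
+The ``family`` label defined here is checked against the driver's family
+by the integral container to ensure matching implementations.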
+""" + +from __future__ import annotations + +from dxtb._src.typing import Literal + +from ...base import BaseIntegral + +__all__ = ["IntegralPytorch"] + + +class PytorchImplementation: + """ + Simple label for ``PyTorch``-based integral implementations. + """ + + family: Literal["PyTorch"] = "PyTorch" + """Label for integral implementation family.""" + + +class IntegralPytorch(PytorchImplementation, BaseIntegral): + """ + ``PyTorch``-based integral implementation. + """ diff --git a/src/dxtb/_src/integral/driver/pytorch/base_driver.py b/src/dxtb/_src/integral/driver/pytorch/base_driver.py deleted file mode 100644 index 22c6e5cd8..000000000 --- a/src/dxtb/_src/integral/driver/pytorch/base_driver.py +++ /dev/null @@ -1,146 +0,0 @@ -# This file is part of dxtb. -# -# SPDX-Identifier: Apache-2.0 -# Copyright (C) 2024 Grimme Group -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Driver: Base Class -================== - -Base class for PyTorch-based drivers. -""" - -from __future__ import annotations - -from abc import abstractmethod - -import torch - -from dxtb import IndexHelper -from dxtb._src.basis.bas import Basis -from dxtb._src.constants import labels -from dxtb._src.typing import Any, Tensor - -from ...base import IntDriver -from .impls import OverlapFunction - -__all__ = ["BaseIntDriverPytorch"] - - -class PytorchImplementation: - """ - Simple label for `PyTorch`-based integral implementations. - """ - - family: int = labels.INTDRIVER_ANALYTICAL - """Label for integral implementation family""" - - -class BaseIntDriverPytorch(PytorchImplementation, IntDriver): - """ - PyTorch-based integral driver. - - Note - ---- - Currently, only the overlap integral is implemented. - """ - - eval_ovlp: OverlapFunction - """Function for overlap calculation.""" - - eval_ovlp_grad: OverlapFunction - """Function for overlap gradient calculation.""" - - def setup(self, positions: Tensor, **kwargs: Any) -> None: - """ - Run the `libcint`-specific driver setup. - - Parameters - ---------- - positions : Tensor - Cartesian coordinates of all atoms (shape: ``(..., nat, 3)``). 
- """ - if self.ihelp.batch_mode == 0: - # setup `Basis` class if not already done - if self._basis is None: - self.basis = Basis( - torch.unique(self.numbers), - self.par, - self.ihelp, - device=self.device, - dtype=self.dtype, - ) - - self._positions_single = positions - else: - - self._positions_batch: list[Tensor] = [] - self._basis_batch: list[Basis] = [] - self._ihelp_batch: list[IndexHelper] = [] - for _batch in range(self.numbers.shape[0]): - # POSITIONS - if self.ihelp.batch_mode == 1: - # pylint: disable=import-outside-toplevel - from tad_mctc.batch import deflate - - mask = kwargs.pop("mask", None) - if mask is not None: - pos = torch.masked_select( - positions[_batch], - mask[_batch], - ).reshape((-1, 3)) - else: - pos = deflate(positions[_batch]) - - nums = deflate(self.numbers[_batch]) - - elif self.ihelp.batch_mode == 2: - pos = positions[_batch] - nums = self.numbers[_batch] - - else: - raise ValueError(f"Unknown batch mode '{self.ihelp.batch_mode}'.") - - self._positions_batch.append(pos) - - # INDEXHELPER - # unfortunately, we need a new IndexHelper for each batch, - # but this is much faster than `calc_overlap` - ihelp = IndexHelper.from_numbers(nums, self.par) - - self._ihelp_batch.append(ihelp) - - # BASIS - bas = Basis( - torch.unique(nums), - self.par, - ihelp, - dtype=self.dtype, - device=self.device, - ) - - self._basis_batch.append(bas) - - self.setup_eval_funcs() - - # setting positions signals successful setup; save current positions to - # catch new positions and run the required re-setup of the driver - self._positions = positions.detach().clone() - - @abstractmethod - def setup_eval_funcs(self) -> None: - """ - Specification of the overlap (gradient) evaluation functions - (`eval_ovlp` and `eval_ovlp_grad`). - """ diff --git a/src/dxtb/_src/integral/driver/pytorch/base_implementation.py b/src/dxtb/_src/integral/driver/pytorch/base_implementation.py deleted file mode 100644 index 6464b2da7..000000000 --- a/src/dxtb/_src/integral/driver/pytorch/base_implementation.py +++ /dev/null @@ -1,69 +0,0 @@ -# This file is part of dxtb. -# -# SPDX-Identifier: Apache-2.0 -# Copyright (C) 2024 Grimme Group -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -PyTorch: Base Implementation -============================ - -Base class for PyTorch-based integral implementations. -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -from dxtb._src.constants import labels - -from ...base import BaseIntegralImplementation - -if TYPE_CHECKING: - from .base_driver import BaseIntDriverPytorch -del TYPE_CHECKING - -__all__ = ["IntegralImplementationPytorch"] - - -class PytorchImplementation: - """ - Simple label for `PyTorch`-based integral implementations. 
-    """
-
-    family: int = labels.INTDRIVER_ANALYTICAL
-    """Label for integral implementation family"""
-
-
-class IntegralImplementationPytorch(
-    PytorchImplementation,
-    BaseIntegralImplementation,
-):
-    """PyTorch-based integral implementation"""
-
-    def checks(self, driver: BaseIntDriverPytorch) -> None:
-        """
-        Check if the type of integral driver is correct.
-
-        Parameters
-        ----------
-        driver : BaseIntDriverPytorch
-            Integral driver for the calculation.
-        """
-        super().checks(driver)
-
-        # pylint: disable=import-outside-toplevel
-        from .base_driver import BaseIntDriverPytorch
-
-        if not isinstance(driver, BaseIntDriverPytorch):
-            raise RuntimeError("Wrong integral driver selected.")
diff --git a/src/dxtb/_src/integral/driver/pytorch/dipole.py b/src/dxtb/_src/integral/driver/pytorch/dipole.py
new file mode 100644
index 000000000..53ef142a0
--- /dev/null
+++ b/src/dxtb/_src/integral/driver/pytorch/dipole.py
@@ -0,0 +1,109 @@
+# This file is part of dxtb.
+#
+# SPDX-Identifier: Apache-2.0
+# Copyright (C) 2024 Grimme Group
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Implementation: Dipole
+======================
+
+PyTorch-based dipole integral implementations.
+"""
+
+from __future__ import annotations
+
+import torch
+
+from dxtb._src.constants import defaults
+from dxtb._src.typing import Literal, Tensor
+
+from ...types import DipoleIntegral
+from .base import IntegralPytorch
+from .driver import BaseIntDriverPytorch
+
+__all__ = ["DipolePytorch"]
+
+
+class DipolePytorch(DipoleIntegral, IntegralPytorch):
+    """
+    Dipole integral from atomic orbitals.
+    """
+
+    uplo: Literal["n", "u", "l"] = "l"
+    """
+    Whether the matrix of unique shell pairs should be created as a
+    triangular matrix (``l``: lower, ``u``: upper) or full matrix (``n``).
+    Defaults to ``l`` (lower triangular matrix).
+    """
+
+    cutoff: Tensor | float | int | None = defaults.INTCUTOFF
+    """
+    Real-space cutoff for integral calculation in Bohr. Defaults to
+    ``constants.defaults.INTCUTOFF``.
+    """
+
+    def __init__(
+        self,
+        uplo: Literal["n", "N", "u", "U", "l", "L"] = "l",
+        cutoff: Tensor | float | int | None = defaults.INTCUTOFF,
+        device: torch.device | None = None,
+        dtype: torch.dtype | None = None,
+    ):
+        super().__init__(device=device, dtype=dtype)
+        self.cutoff = cutoff
+
+        if uplo not in ("n", "N", "u", "U", "l", "L"):
+            raise ValueError(f"Unknown option for `uplo` chosen: '{uplo}'.")
+        self.uplo = uplo.casefold()  # type: ignore
+
+        raise NotImplementedError(
+            "PyTorch versions of multipole moments are not implemented. "
+            "Use `libcint` as integral driver."
+        )
+
+    def build(self, driver: BaseIntDriverPytorch) -> Tensor:
+        """
+        Integral calculation of unique shell pairs, using the
+        McMurchie-Davidson algorithm.
+
+        Parameters
+        ----------
+        driver : BaseIntDriverPytorch
+            Integral driver for the calculation.
+
+        Returns
+        -------
+        Tensor
+            Integral matrix of shape ``(..., norb, norb, 3)``.
+        """
+        super().checks(driver)
+        raise NotImplementedError
+
+    def get_gradient(self, driver: BaseIntDriverPytorch) -> Tensor:
+        """
+        Dipole integral gradient calculation of unique shell pairs, using the
+        McMurchie-Davidson algorithm.
+
+        Parameters
+        ----------
+        driver : BaseIntDriverPytorch
+            Integral driver for the calculation.
+
+        Returns
+        -------
+        Tensor
+            Integral gradient of shape ``(..., norb, norb, 3, 3)``.
+        """
+        super().checks(driver)
+        raise NotImplementedError
diff --git a/src/dxtb/_src/integral/driver/pytorch/driver.py b/src/dxtb/_src/integral/driver/pytorch/driver.py
index 262ed5eab..bac183e3a 100644
--- a/src/dxtb/_src/integral/driver/pytorch/driver.py
+++ b/src/dxtb/_src/integral/driver/pytorch/driver.py
@@ -23,7 +23,16 @@
 
 from __future__ import annotations
 
-from .base_driver import BaseIntDriverPytorch
+from abc import abstractmethod
+
+import torch
+
+from dxtb import IndexHelper
+from dxtb._src.basis.bas import Basis
+from dxtb._src.typing import Any, Tensor
+
+from ...base import IntDriver
+from .base import PytorchImplementation
 from .impls import (
     OverlapAG_V1,
     OverlapAG_V2,
@@ -33,12 +42,112 @@
 )
 
 __all__ = [
+    "BaseIntDriverPytorch",
     "IntDriverPytorch",
     "IntDriverPytorchNoAnalytical",
     "IntDriverPytorchLegacy",
 ]
 
 
+class BaseIntDriverPytorch(PytorchImplementation, IntDriver):
+    """
+    PyTorch-based integral driver.
+
+    Note
+    ----
+    Currently, only the overlap integral is implemented.
+    """
+
+    eval_ovlp: OverlapFunction
+    """Function for overlap calculation."""
+
+    eval_ovlp_grad: OverlapFunction
+    """Function for overlap gradient calculation."""
+
+    def setup(self, positions: Tensor, **kwargs: Any) -> None:
+        """
+        Run the PyTorch-specific driver setup.
+
+        Parameters
+        ----------
+        positions : Tensor
+            Cartesian coordinates of all atoms (shape: ``(..., nat, 3)``).
+ """ + if self.ihelp.batch_mode == 0: + # setup `Basis` class if not already done + if self._basis is None: + self.basis = Basis( + torch.unique(self.numbers), + self.par, + self.ihelp, + device=self.device, + dtype=self.dtype, + ) + + self._positions_single = positions + else: + + self._positions_batch: list[Tensor] = [] + self._basis_batch: list[Basis] = [] + self._ihelp_batch: list[IndexHelper] = [] + for _batch in range(self.numbers.shape[0]): + # POSITIONS + if self.ihelp.batch_mode == 1: + # pylint: disable=import-outside-toplevel + from tad_mctc.batch import deflate + + mask = kwargs.pop("mask", None) + if mask is not None: + pos = torch.masked_select( + positions[_batch], + mask[_batch], + ).reshape((-1, 3)) + else: + pos = deflate(positions[_batch]) + + nums = deflate(self.numbers[_batch]) + + elif self.ihelp.batch_mode == 2: + pos = positions[_batch] + nums = self.numbers[_batch] + + else: + raise ValueError(f"Unknown batch mode '{self.ihelp.batch_mode}'.") + + self._positions_batch.append(pos) + + # INDEXHELPER + # unfortunately, we need a new IndexHelper for each batch, + # but this is much faster than `calc_overlap` + ihelp = IndexHelper.from_numbers(nums, self.par) + + self._ihelp_batch.append(ihelp) + + # BASIS + bas = Basis( + torch.unique(nums), + self.par, + ihelp, + dtype=self.dtype, + device=self.device, + ) + + self._basis_batch.append(bas) + + self.setup_eval_funcs() + + # setting positions signals successful setup; save current positions to + # catch new positions and run the required re-setup of the driver + self._positions = positions.detach().clone() + + @abstractmethod + def setup_eval_funcs(self) -> None: + """ + Specification of the overlap (gradient) evaluation functions + (`eval_ovlp` and `eval_ovlp_grad`). + """ + + class IntDriverPytorch(BaseIntDriverPytorch): """ PyTorch-based integral driver. 
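
For orientation, the relocated base driver above is consumed through the new `DriverManager` and the factory functions introduced later in this patch. Below is a minimal sketch of that workflow; the H2 geometry and the direct imports from the private `dxtb._src` modules are illustrative assumptions, not part of the patch:

    import torch

    from dxtb import GFN1_XTB, IndexHelper, labels
    from dxtb._src.integral.driver.manager import DriverManager
    from dxtb._src.integral.factory import new_overlap

    # hypothetical H2 input (atomic numbers, positions in Bohr)
    numbers = torch.tensor([1, 1])
    positions = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.4]])

    ihelp = IndexHelper.from_numbers(numbers, GFN1_XTB)

    # the manager holds the selection logic: the label decides whether the
    # libcint driver or one of the PyTorch drivers is instantiated
    mgr = DriverManager(labels.INTDRIVER_ANALYTICAL, dtype=positions.dtype)
    mgr.create_driver(numbers, GFN1_XTB, ihelp)
    mgr.setup_driver(positions)

    # integral classes are created separately and only receive the driver
    overlap = new_overlap(mgr.driver_type, dtype=positions.dtype)
    s = overlap.build(mgr.driver)
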
diff --git a/src/dxtb/_src/integral/driver/pytorch/impls/md/explicit.py b/src/dxtb/_src/integral/driver/pytorch/impls/md/explicit.py index e36585874..709b5ef4c 100644 --- a/src/dxtb/_src/integral/driver/pytorch/impls/md/explicit.py +++ b/src/dxtb/_src/integral/driver/pytorch/impls/md/explicit.py @@ -803,7 +803,7 @@ def de_p( e040 = rpj * e030 + e031 f030 = 3 * e020 - b * e040 - e130 = rpi * e030 + e031 + # e130 = rpi * e030 + e031 e032 = xij * e021 + rpj * e022 e131 = xij * e030 + rpi * e031 + 2 * e032 e230 = rpi * e130 + e131 @@ -1493,7 +1493,7 @@ def de_f( e202 = xij * e101 e301 = xij * e200 + rpi * e201 + 2 * e202 - e310 = rpj * e300 + e301 + # e310 = rpj * e300 + e301 e302 = xij * e201 + rpi * e202 e311 = xij * e300 + rpj * e301 + 2 * e302 @@ -1525,7 +1525,7 @@ def de_f( f120 = a * e220 - e020 - e320 = rpj * e310 + e311 + # e320 = rpj * e310 + e311 f220 = a * e320 - 2 * e120 e303 = xij * e202 @@ -1575,7 +1575,7 @@ def de_f( e202 = xij * e101 e301 = xij * e200 + rpi * e201 + 2 * e202 - e310 = rpj * e300 + e301 + # e310 = rpj * e300 + e301 e021 = xij * e010 + rpj * e011 e030 = rpj * e020 + e021 @@ -1583,7 +1583,7 @@ def de_f( e022 = xij * e011 e031 = xij * e020 + rpj * e021 + 2 * e022 e301 = xij * e200 + rpi * e201 + 2 * e202 - e310 = rpj * e300 + e301 + # e310 = rpj * e300 + e301 e130 = rpi * e030 + e031 e032 = xij * e021 + rpj * e022 diff --git a/src/dxtb/_src/integral/driver/pytorch/impls/md/recursion.py b/src/dxtb/_src/integral/driver/pytorch/impls/md/recursion.py index 2006b358a..b8b368074 100644 --- a/src/dxtb/_src/integral/driver/pytorch/impls/md/recursion.py +++ b/src/dxtb/_src/integral/driver/pytorch/impls/md/recursion.py @@ -457,7 +457,7 @@ def md_recursion_gradient( raise IntegralTransformError() from e # cartesian overlap and overlap gradient - s3d = vec.new_zeros(*[*vec.shape[:-1], ncarti, ncartj]) + # s3d = vec.new_zeros(*[*vec.shape[:-1], ncarti, ncartj]) ds3d = vec.new_zeros(*[*vec.shape[:-1], 3, ncarti, ncartj]) ai, aj = alpha[0].unsqueeze(-1), alpha[1].unsqueeze(-2) diff --git a/src/dxtb/_src/integral/driver/pytorch/impls/md/trafo.py b/src/dxtb/_src/integral/driver/pytorch/impls/md/trafo.py index c7c2593bc..bc6a8bac4 100644 --- a/src/dxtb/_src/integral/driver/pytorch/impls/md/trafo.py +++ b/src/dxtb/_src/integral/driver/pytorch/impls/md/trafo.py @@ -44,13 +44,13 @@ s6 = sqrt(6.0) s15 = sqrt(15.0) s15_4 = sqrt(15.0 / 4.0) -s45 = sqrt(45.0) +# s45 = sqrt(45.0) s45_8 = sqrt(45.0 / 8.0) # d38 = 3.0 / 8.0 # d34 = 3.0 / 4.0 # s5_16 = sqrt(5.0 / 16.0) -s10 = sqrt(10.0) +# s10 = sqrt(10.0) # s10_8 = sqrt(10.0 / 8.0) # s35_4 = sqrt(35.0 / 4.0) # s35_8 = sqrt(35.0 / 8.0) diff --git a/src/dxtb/_src/integral/driver/pytorch/overlap.py b/src/dxtb/_src/integral/driver/pytorch/overlap.py index edd2016f4..1a4ceba39 100644 --- a/src/dxtb/_src/integral/driver/pytorch/overlap.py +++ b/src/dxtb/_src/integral/driver/pytorch/overlap.py @@ -29,22 +29,23 @@ from dxtb._src.constants import defaults from dxtb._src.typing import Literal, Tensor -from .base_implementation import IntegralImplementationPytorch +from ...types import OverlapIntegral +from .base import IntegralPytorch from .driver import BaseIntDriverPytorch from .impls import OverlapFunction __all__ = ["OverlapPytorch"] -class OverlapPytorch(IntegralImplementationPytorch): +class OverlapPytorch(OverlapIntegral, IntegralPytorch): """ - Overlap from atomic orbitals. + Overlap integral from atomic orbitals. - Use the `build()` method to calculate the overlap integral. 
The returned
-    matrix uses a custom autograd function to calculate the backward pass with
-    the analytical gradient.
-    For the full gradient, i.e., a matrix of shape `(nb, norb, norb, 3)`, the
-    `get_gradient()` method should be used.
+    Use the :meth:`.build` method to calculate the overlap integral. The
+    returned matrix uses a custom autograd function to calculate the
+    backward pass with the analytical gradient.
+    For the full gradient, i.e., a matrix of shape ``(..., norb, norb, 3)``,
+    the :meth:`.get_gradient` method should be used.
     """
 
     uplo: Literal["n", "u", "l"] = "l"
@@ -82,12 +83,12 @@ def build(self, driver: BaseIntDriverPytorch) -> Tensor:
         Parameters
         ----------
         driver : BaseIntDriverPytorch
-            Integral driver for the calculation.
+            The integral driver for the calculation.
 
         Returns
         -------
         Tensor
-            Overlap matrix.
+            Overlap integral matrix of shape ``(..., norb, norb)``.
         """
         super().checks(driver)
 
@@ -115,16 +116,16 @@ def get_gradient(self, driver: BaseIntDriverPytorch) -> Tensor:
         Returns
         -------
         Tensor
-            Overlap gradient of shape `(nb, norb, norb, 3)`.
+            Overlap gradient of shape ``(..., norb, norb, 3)``.
         """
         super().checks(driver)
 
         if driver.ihelp.batch_mode > 0:
-            self.grad = self._batch(driver.eval_ovlp_grad, driver)
+            self.gradient = self._batch(driver.eval_ovlp_grad, driver)
         else:
-            self.grad = self._single(driver.eval_ovlp_grad, driver)
+            self.gradient = self._single(driver.eval_ovlp_grad, driver)
 
-        return self.grad
+        return self.gradient
 
     def _single(self, fcn: OverlapFunction, driver: BaseIntDriverPytorch) -> Tensor:
         if not isinstance(driver, BaseIntDriverPytorch):
diff --git a/src/dxtb/_src/integral/driver/pytorch/quadrupole.py b/src/dxtb/_src/integral/driver/pytorch/quadrupole.py
new file mode 100644
index 000000000..db9a06db8
--- /dev/null
+++ b/src/dxtb/_src/integral/driver/pytorch/quadrupole.py
@@ -0,0 +1,109 @@
+# This file is part of dxtb.
+#
+# SPDX-Identifier: Apache-2.0
+# Copyright (C) 2024 Grimme Group
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Implementation: Quadrupole
+==========================
+
+PyTorch-based quadrupole integral implementations.
+"""
+
+from __future__ import annotations
+
+import torch
+
+from dxtb._src.constants import defaults
+from dxtb._src.typing import Literal, Tensor
+
+from ...types import QuadrupoleIntegral
+from .base import IntegralPytorch
+from .driver import BaseIntDriverPytorch
+
+__all__ = ["QuadrupolePytorch"]
+
+
+class QuadrupolePytorch(QuadrupoleIntegral, IntegralPytorch):
+    """
+    Quadrupole integral from atomic orbitals.
+    """
+
+    uplo: Literal["n", "u", "l"] = "l"
+    """
+    Whether the matrix of unique shell pairs should be created as a
+    triangular matrix (``l``: lower, ``u``: upper) or full matrix (``n``).
+    Defaults to ``l`` (lower triangular matrix).
+    """
+
+    cutoff: Tensor | float | int | None = defaults.INTCUTOFF
+    """
+    Real-space cutoff for integral calculation in Bohr. Defaults to
+    ``constants.defaults.INTCUTOFF``.
+    """
+
+    def __init__(
+        self,
+        uplo: Literal["n", "N", "u", "U", "l", "L"] = "l",
+        cutoff: Tensor | float | int | None = defaults.INTCUTOFF,
+        device: torch.device | None = None,
+        dtype: torch.dtype | None = None,
+    ):
+        super().__init__(device=device, dtype=dtype)
+        self.cutoff = cutoff
+
+        if uplo not in ("n", "N", "u", "U", "l", "L"):
+            raise ValueError(f"Unknown option for `uplo` chosen: '{uplo}'.")
+        self.uplo = uplo.casefold()  # type: ignore
+
+        raise NotImplementedError(
+            "PyTorch versions of multipole moments are not implemented. "
+            "Use `libcint` as integral driver."
+        )
+
+    def build(self, driver: BaseIntDriverPytorch) -> Tensor:
+        """
+        Integral calculation of unique shell pairs, using the
+        McMurchie-Davidson algorithm.
+
+        Parameters
+        ----------
+        driver : BaseIntDriverPytorch
+            Integral driver for the calculation.
+
+        Returns
+        -------
+        Tensor
+            Integral matrix of shape ``(..., norb, norb, 3)``.
+        """
+        super().checks(driver)
+        raise NotImplementedError
+
+    def get_gradient(self, driver: BaseIntDriverPytorch) -> Tensor:
+        """
+        Quadrupole integral gradient calculation of unique shell pairs, using
+        the McMurchie-Davidson algorithm.
+
+        Parameters
+        ----------
+        driver : BaseIntDriverPytorch
+            Integral driver for the calculation.
+
+        Returns
+        -------
+        Tensor
+            Integral gradient of shape ``(..., norb, norb, 3, 3)``.
+        """
+        super().checks(driver)
+        raise NotImplementedError
diff --git a/src/dxtb/_src/integral/factory.py b/src/dxtb/_src/integral/factory.py
index ea195448f..0adb5f89d 100644
--- a/src/dxtb/_src/integral/factory.py
+++ b/src/dxtb/_src/integral/factory.py
@@ -18,7 +18,7 @@
 Factories
 =========
 
-Factory functions for integral drivers.
+Factory functions for integral classes.
 """
 
 from __future__ import annotations
@@ -30,136 +30,222 @@
 from dxtb._src.param import Param
 from dxtb._src.typing import TYPE_CHECKING, Any, Tensor
 
-from .base import IntDriver
-
 if TYPE_CHECKING:
-    from .driver.libcint import IntDriverLibcint, OverlapLibcint
-    from .driver.pytorch import (
-        IntDriverPytorch,
-        IntDriverPytorchLegacy,
-        IntDriverPytorchNoAnalytical,
-        OverlapPytorch,
-    )
+    from dxtb._src.xtb.gfn1 import GFN1Hamiltonian
+    from dxtb._src.xtb.gfn2 import GFN2Hamiltonian
+
+    from .driver.libcint import DipoleLibcint, OverlapLibcint, QuadrupoleLibcint
+    from .driver.pytorch import DipolePytorch, OverlapPytorch, QuadrupolePytorch
+    from .types import DipoleIntegral, OverlapIntegral, QuadrupoleIntegral
+
+__all__ = ["new_hcore", "new_overlap", "new_dipint", "new_quadint"]
 
-__all__ = ["new_driver", "new_overlap"]
+
+################################################################################
 
 
-def new_driver(
-    name: int,
+def new_hcore(
     numbers: Tensor,
     par: Param,
+    ihelp: IndexHelper,
     device: torch.device | None = None,
     dtype: torch.dtype | None = None,
-) -> IntDriver:
-    if name == labels.INTDRIVER_LIBCINT:
-        return new_driver_libcint(numbers, par, device=device, dtype=dtype)
+) -> GFN1Hamiltonian | GFN2Hamiltonian:
+    if par.meta is None:
+        raise ValueError(
+            "The `meta` information field is missing in the parametrization. "
+            "No xTB core Hamiltonian can be selected and instantiated."
+        )
 
-    if name == labels.INTDRIVER_ANALYTICAL:
-        return new_driver_pytorch(numbers, par, device=device, dtype=dtype)
+    if par.meta.name is None:
+        raise ValueError(
+            "The `name` field of the meta information is missing in the "
+            "parametrization. No xTB core Hamiltonian can be selected and "
+            "instantiated."
+ ) - if name == labels.INTDRIVER_AUTOGRAD: - return new_driver_pytorch2(numbers, par, device=device, dtype=dtype) + if par.meta.name.casefold() in ("gfn1-xtb", "gfn1"): + return new_hcore_gfn1(numbers, ihelp, par, device=device, dtype=dtype) - if name == labels.INTDRIVER_LEGACY: - return new_driver_legacy(numbers, par, device=device, dtype=dtype) + if par.meta.name.casefold() in ("gfn2-xtb", "gfn2"): + return new_hcore_gfn2(numbers, ihelp, par, device=device, dtype=dtype) - raise ValueError(f"Unknown integral driver '{ labels.INTDRIVER_MAP[name]}'.") + raise ValueError(f"Unsupported Hamiltonian type: {par.meta.name}") -def new_driver_libcint( +def new_hcore_gfn1( numbers: Tensor, - par: Param, + ihelp: IndexHelper, + par: Param | None = None, device: torch.device | None = None, dtype: torch.dtype | None = None, -) -> IntDriverLibcint: +) -> GFN1Hamiltonian: # pylint: disable=import-outside-toplevel - from .driver.libcint import IntDriverLibcint as IntDriver + from dxtb._src.xtb.gfn1 import GFN1Hamiltonian as Hamiltonian - ihelp = IndexHelper.from_numbers(numbers, par) - return IntDriver(numbers, par, ihelp, device=device, dtype=dtype) + if par is None: + # pylint: disable=import-outside-toplevel + from dxtb import GFN1_XTB as par + return Hamiltonian(numbers, par, ihelp, device=device, dtype=dtype) -def new_driver_pytorch( + +def new_hcore_gfn2( numbers: Tensor, - par: Param, + ihelp: IndexHelper, + par: Param | None = None, device: torch.device | None = None, dtype: torch.dtype | None = None, -) -> IntDriverPytorch: +) -> GFN2Hamiltonian: # pylint: disable=import-outside-toplevel - from .driver.pytorch import IntDriverPytorch as IntDriver + from dxtb._src.xtb.gfn2 import GFN2Hamiltonian as Hamiltonian - ihelp = IndexHelper.from_numbers(numbers, par) - return IntDriver(numbers, par, ihelp, device=device, dtype=dtype) + if par is None: + # pylint: disable=import-outside-toplevel + from dxtb import GFN2_XTB as par + return Hamiltonian(numbers, par, ihelp, device=device, dtype=dtype) -def new_driver_pytorch2( - numbers: Tensor, - par: Param, + +################################################################################ + + +def new_overlap( + driver: int = labels.INTDRIVER_LIBCINT, device: torch.device | None = None, dtype: torch.dtype | None = None, -) -> IntDriverPytorchNoAnalytical: + **kwargs: Any, +) -> OverlapIntegral: + # Determine which integral class to instantiate based on the type + if driver == labels.INTDRIVER_LIBCINT: + return new_overlap_libcint(device=device, dtype=dtype, **kwargs) + + if driver in ( + labels.INTDRIVER_ANALYTICAL, + labels.INTDRIVER_AUTOGRAD, + labels.INTDRIVER_LEGACY, + ): + return new_overlap_pytorch(device=device, dtype=dtype, **kwargs) + + raise ValueError(f"Unknown integral driver '{driver}'.") + + +def new_overlap_libcint( + device: torch.device | None = None, + dtype: torch.dtype | None = None, + **kwargs: Any, +) -> OverlapLibcint: # pylint: disable=import-outside-toplevel - from .driver.pytorch import IntDriverPytorchNoAnalytical as IntDriver + from .driver.libcint import OverlapLibcint as Overlap - ihelp = IndexHelper.from_numbers(numbers, par) - return IntDriver(numbers, par, ihelp, device=device, dtype=dtype) + if kwargs.pop("force_cpu_for_libcint", True): + device = torch.device("cpu") + return Overlap(device=device, dtype=dtype, **kwargs) -def new_driver_legacy( - numbers: Tensor, - par: Param, + +def new_overlap_pytorch( + device: torch.device | None = None, + dtype: torch.dtype | None = None, + **kwargs: Any, +) -> OverlapPytorch: + # 
pylint: disable=import-outside-toplevel + from .driver.pytorch import OverlapPytorch as Overlap + + return Overlap(device=device, dtype=dtype, **kwargs) + + +################################################################################ + + +def new_dipint( + driver: int = labels.INTDRIVER_LIBCINT, device: torch.device | None = None, dtype: torch.dtype | None = None, -) -> IntDriverPytorchLegacy: + **kwargs: Any, +) -> DipoleIntegral: + # Determine which integral class to instantiate based on the type + if driver == labels.INTDRIVER_LIBCINT: + return new_dipint_libcint(device=device, dtype=dtype, **kwargs) + + if driver in ( + labels.INTDRIVER_ANALYTICAL, + labels.INTDRIVER_AUTOGRAD, + labels.INTDRIVER_LEGACY, + ): + return new_dipint_pytorch(device=device, dtype=dtype, **kwargs) + + raise ValueError(f"Unknown integral driver '{driver}'.") + + +def new_dipint_libcint( + device: torch.device | None = None, + dtype: torch.dtype | None = None, + **kwargs: Any, +) -> DipoleLibcint: + # pylint: disable=import-outside-toplevel + from .driver.libcint import DipoleLibcint as _Dipole + + if kwargs.pop("force_cpu_for_libcint", True): + device = torch.device("cpu") + + return _Dipole(device=device, dtype=dtype, **kwargs) + + +def new_dipint_pytorch( + device: torch.device | None = None, + dtype: torch.dtype | None = None, + **kwargs: Any, +) -> DipolePytorch: # pylint: disable=import-outside-toplevel - from .driver.pytorch import IntDriverPytorchLegacy as IntDriver + from .driver.pytorch import DipolePytorch as _Dipole - ihelp = IndexHelper.from_numbers(numbers, par) - return IntDriver(numbers, par, ihelp, device=device, dtype=dtype) + return _Dipole(device=device, dtype=dtype, **kwargs) ################################################################################ -def new_overlap( +def new_quadint( driver: int = labels.INTDRIVER_LIBCINT, device: torch.device | None = None, dtype: torch.dtype | None = None, **kwargs: Any, -) -> OverlapLibcint | OverlapPytorch: - # Determine which overlap class to instantiate based on the type +) -> QuadrupoleIntegral: + # Determine which integral class to instantiate based on the type if driver == labels.INTDRIVER_LIBCINT: - if kwargs.pop("force_cpu_for_libcint", True): - device = torch.device("cpu") - return new_overlap_libcint(device=device, dtype=dtype, **kwargs) + return new_quadint_libcint(device=device, dtype=dtype, **kwargs) if driver in ( labels.INTDRIVER_ANALYTICAL, labels.INTDRIVER_AUTOGRAD, labels.INTDRIVER_LEGACY, ): - return new_overlap_pytorch(device=device, dtype=dtype, **kwargs) + return new_quadint_pytorch(device=device, dtype=dtype, **kwargs) raise ValueError(f"Unknown integral driver '{driver}'.") -def new_overlap_libcint( +def new_quadint_libcint( device: torch.device | None = None, dtype: torch.dtype | None = None, **kwargs: Any, -) -> OverlapLibcint: +) -> QuadrupoleLibcint: # pylint: disable=import-outside-toplevel - from .driver.libcint import OverlapLibcint + from .driver.libcint import QuadrupoleLibcint as Quadrupole - return OverlapLibcint(device=device, dtype=dtype, **kwargs) + if kwargs.pop("force_cpu_for_libcint", True): + device = torch.device("cpu") + return Quadrupole(device=device, dtype=dtype, **kwargs) -def new_overlap_pytorch( + +def new_quadint_pytorch( device: torch.device | None = None, dtype: torch.dtype | None = None, **kwargs: Any, -) -> OverlapPytorch: +) -> QuadrupolePytorch: # pylint: disable=import-outside-toplevel - from .driver.pytorch import OverlapPytorch + from .driver.pytorch import QuadrupolePytorch as 
Quadrupole - return OverlapPytorch(device=device, dtype=dtype, **kwargs) + return Quadrupole(device=device, dtype=dtype, **kwargs) diff --git a/src/dxtb/_src/integral/types/__init__.py b/src/dxtb/_src/integral/types/__init__.py index ecc9256bb..808067cb5 100644 --- a/src/dxtb/_src/integral/types/__init__.py +++ b/src/dxtb/_src/integral/types/__init__.py @@ -22,12 +22,12 @@ Currently, the following integral types are supported: -- :class:`.HCore` (core Hamiltonian) - :class:`.Overlap` - :class:`.Dipole` - :class:`.Quadrupole` + +Note that the Hamiltonian is different as it does not require a driver. """ from .dipole import * -from .h0 import * from .overlap import * from .quadrupole import * diff --git a/src/dxtb/_src/integral/types/base.py b/src/dxtb/_src/integral/types/base.py deleted file mode 100644 index 9be09248f..000000000 --- a/src/dxtb/_src/integral/types/base.py +++ /dev/null @@ -1,243 +0,0 @@ -""" -Integral Types: Base -==================== - -Base class for Integrals. -""" - -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING - -import torch -from tad_mctc.exceptions import DtypeError - -from dxtb._src.typing import Any, PathLike, Tensor, TensorLike, override - -if TYPE_CHECKING: - from ..base import BaseIntegralImplementation, IntDriver -del TYPE_CHECKING - -__all__ = ["BaseIntegral"] - - -class IntegralABC(ABC): - """ - Abstract base class for integrals. - - This class works as a wrapper for the actual integral, which is stored in - the `integral` attribute of this class. - """ - - @abstractmethod - def build(self, positions: Tensor, **kwargs: Any) -> Tensor: - """ - Create the integral matrix. This method only calls the `build` method - of the underlying `BaseIntegralType`. - - Parameters - ---------- - positions : Tensor - Cartesian coordinates of all atoms (shape: ``(..., nat, 3)``). - - Returns - ------- - Tensor - Integral matrix. - - Note - ---- - The matrix is returned and also saved internally in the `mat` attribute. - """ - - -class BaseIntegral(IntegralABC, TensorLike): - """ - Base class for integrals. - - This class works as a wrapper for the actual integral, which is stored in - the `integral` attribute of this class. - """ - - label: str - """Identifier label for integral type.""" - - integral: BaseIntegralImplementation - """Instance of actual integral type.""" - - __slots__ = ["integral"] - - def __init__( - self, - device: torch.device | None = None, - dtype: torch.dtype | None = None, - ) -> None: - super().__init__(device=device, dtype=dtype) - self.label = self.__class__.__name__ - - def build(self, driver: IntDriver) -> Tensor: - """ - Calculation of the integral (matrix). This method only calls the - :meth:`build` method of the underlying - :class:`BaseIntegralImplementation`. - - Parameters - ---------- - driver : IntDriver - The integral driver for the calculation. - - Returns - ------- - Tensor - Integral matrix. - """ - return self.integral.build(driver) - - def get_gradient(self, driver: IntDriver) -> Tensor: - """ - Calculation of the nuclear integral derivative (matrix). This method - only calls the :meth:`get_gradient` method of the underlying - :class:`BaseIntegralImplementation`. - - Parameters - ---------- - driver : IntDriver - The integral driver for the calculation. - - Returns - ------- - Tensor - Nuclear integral derivative matrix. 
- """ - return self.integral.get_gradient(driver) - - @property - def matrix(self) -> Tensor | None: - """ - Shortcut for matrix representation of the integral. - - Returns - ------- - Tensor | None - Integral matrix or ``None`` if not calculated yet. - """ - return self.integral.matrix - - @matrix.setter - def matrix(self, mat: Tensor) -> None: - """ - Shortcut for matrix representation of the integral. - - Parameters - ---------- - mat : Tensor - Integral matrix. - """ - self.integral.matrix = mat - - def to_pt(self, path: PathLike | None = None) -> None: - """ - Save the integral matrix to a file. - - Parameters - ---------- - path : PathLike | None - Path to the file where the integral matrix should be saved. If - ``None``, the matrix is saved to the default location. - """ - self.integral.to_pt(path) - - def type(self, dtype: torch.dtype) -> BaseIntegral: - """ - Returns a copy of the :class:`BaseIntegral` instance with specified - floating point type. - - This method overwrites the usual approach because the - :class:`BaseIntegral` class only contains the integral, which has to be - moved. - - Parameters - ---------- - dtype : torch.dtype - Floating point type. - - Returns - ------- - BaseIntegral - A copy of the :class:`BaseIntegral` instance with the specified - dtype. - - Raises - ------ - RuntimeError - If the ``__slots__`` attribute is not set in the class. - DtypeError - If the specified dtype is not allowed. - """ - if self.dtype == dtype: - return self - - if len(self.__slots__) == 0: - raise RuntimeError( - f"The `type` method requires setting ``__slots__`` in the " - f"'{self.__class__.__name__}' class." - ) - - if dtype not in self.allowed_dtypes: - raise DtypeError( - f"Only '{self.allowed_dtypes}' allowed (received '{dtype}')." - ) - - self.integral = self.integral.type(dtype) - self.override_dtype(dtype) - return self - - @override - def to(self, device: torch.device) -> BaseIntegral: - """ - Returns a copy of the :class:`.BaseIntegral` instance on the specified - device. - - This method overwrites the usual approach because the - :class:`.BaseIntegral` class only contains the integral, which has to be - moved. - - Parameters - ---------- - device : torch.device - Device to which all associated tensors should be moved. - - Returns - ------- - BaseIntegral - A copy of the :class:`.BaseIntegral` instance placed on the - specified device. - - Raises - ------ - RuntimeError - If the ``__slots__`` attribute is not set in the class. - """ - if self.device == device: - return self - - if len(self.__slots__) == 0: - raise RuntimeError( - f"The `to` method requires setting ``__slots__`` in the " - f"'{self.__class__.__name__}' class." - ) - - self.integral = self.integral.to(device) - self.override_device(device) - return self - - def __str__(self) -> str: - mat = self.integral._matrix - matinfo = mat.shape if mat is not None else None - d = {**self.__dict__, "matrix": matinfo} - d.pop("label") - return f"{self.label}({d})" - - def __repr__(self) -> str: - return str(self) diff --git a/src/dxtb/_src/integral/types/dipole.py b/src/dxtb/_src/integral/types/dipole.py index 207e616ed..ba8f7c87d 100644 --- a/src/dxtb/_src/integral/types/dipole.py +++ b/src/dxtb/_src/integral/types/dipole.py @@ -15,54 +15,70 @@ # See the License for the specific language governing permissions and # limitations under the License. """ +Integral Types: Dipole +====================== + Dipole integral. 
""" from __future__ import annotations -import torch - -from dxtb._src.constants import labels -from dxtb._src.typing import TYPE_CHECKING, Any +from tad_mctc.math import einsum -from .base import BaseIntegral +from dxtb._src.typing import Tensor -if TYPE_CHECKING: - from ..driver.libcint import DipoleLibcint +from ..base import BaseIntegral -__all__ = ["Dipole"] +__all__ = ["DipoleIntegral"] -class Dipole(BaseIntegral): +class DipoleIntegral(BaseIntegral): """ Dipole integral from atomic orbitals. """ - integral: DipoleLibcint - """Instance of actual dipole integral type.""" + def shift_r0_rj(self, overlap: Tensor, pos: Tensor) -> Tensor: + r""" + Shift the centering of the dipole integral (moment operator) from the + origin (:math:`r0 = r - (0, 0, 0)`) to atoms (ket index, + :math:`rj = r - r_j`). - def __init__( - self, - driver: int = labels.INTDRIVER_LIBCINT, - device: torch.device | None = None, - dtype: torch.dtype | None = None, - **kwargs: Any, - ) -> None: - super().__init__(device=device, dtype=dtype) + .. math:: - # Determine which overlap class to instantiate based on the type - if driver == labels.INTDRIVER_LIBCINT: - # pylint: disable=import-outside-toplevel - from ..driver.libcint import DipoleLibcint + \begin{align} + D &= D^{r_j} \\ + &= \langle i | r_j | j \rangle \\ + &= \langle i | r | j \rangle - r_j \langle i | j \rangle \\ + &= \langle i | r_0 | j \rangle - r_j S_{ij} \\ + &= D^{r_0} - r_j S_{ij} + \end{align} - if kwargs.pop("force_cpu_for_libcint", True): - device = torch.device("cpu") + Parameters + ---------- + r0 : Tensor + Origin centered dipole integral. + overlap : Tensor + Overlap integral. + pos : Tensor + Orbital-resolved atomic positions. - self.integral = DipoleLibcint(device=device, dtype=dtype, **kwargs) - elif driver in (labels.INTDRIVER_ANALYTICAL, labels.INTDRIVER_AUTOGRAD): - raise NotImplementedError( - "PyTorch versions of multipole moments are not implemented. " - "Use `libcint` as integral driver." + Raises + ------ + RuntimeError + Shape mismatch between ``positions`` and `overlap`. + The positions must be orbital-resolved. + + Returns + ------- + Tensor + Second-index (ket) atom-centered dipole integral. + """ + if pos.shape[-2] != overlap.shape[-1]: + raise RuntimeError( + "Shape mismatch between positions and overlap integral. " + "The position tensor must be spread to orbital-resolution." ) - else: - raise ValueError(f"Unknown integral driver '{driver}'.") + + shift = einsum("...jx,...ij->...xij", pos, overlap) + self.matrix = self.matrix - shift + return self.matrix diff --git a/src/dxtb/_src/integral/types/h0.py b/src/dxtb/_src/integral/types/h0.py deleted file mode 100644 index e9254319e..000000000 --- a/src/dxtb/_src/integral/types/h0.py +++ /dev/null @@ -1,67 +0,0 @@ -# This file is part of dxtb. -# -# SPDX-Identifier: Apache-2.0 -# Copyright (C) 2024 Grimme Group -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Core Hamiltonian. 
-""" - -from __future__ import annotations - -import torch - -from dxtb import IndexHelper -from dxtb._src.param import Param -from dxtb._src.typing import Tensor -from dxtb._src.xtb.gfn1 import GFN1Hamiltonian -from dxtb._src.xtb.gfn2 import GFN2Hamiltonian - -from .base import BaseIntegral - -__all__ = ["HCore"] - - -class HCore(BaseIntegral): - """ - Hamiltonian integral. - """ - - integral: GFN1Hamiltonian | GFN2Hamiltonian - """Instance of actual GFN Hamiltonian integral.""" - - __slots__ = ["integral"] - - def __init__( - self, - numbers: Tensor, - par: Param, - ihelp: IndexHelper, - device: torch.device | None = None, - dtype: torch.dtype | None = None, - ): - super().__init__(device=device, dtype=dtype) - - if par.meta is not None: - if par.meta.name is not None: - if par.meta.name.casefold() == "gfn1-xtb": - self.integral = GFN1Hamiltonian( - numbers, par, ihelp, device=device, dtype=dtype - ) - elif par.meta.name.casefold() == "gfn2-xtb": - self.integral = GFN2Hamiltonian( - numbers, par, ihelp, device=device, dtype=dtype - ) - else: - raise ValueError(f"Unsupported Hamiltonian type: {par.meta.name}") diff --git a/src/dxtb/_src/integral/types/overlap.py b/src/dxtb/_src/integral/types/overlap.py index cc7d206dc..af22cd848 100644 --- a/src/dxtb/_src/integral/types/overlap.py +++ b/src/dxtb/_src/integral/types/overlap.py @@ -15,46 +15,20 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Overlap -======= +Integral Types: Overlap +======================= The GFNn-xTB overlap matrix. """ from __future__ import annotations -import torch +from ..base import BaseIntegral -from dxtb._src.constants import labels -from dxtb._src.typing import TYPE_CHECKING, Any +__all__ = ["OverlapIntegral"] -from ..factory import new_overlap -from .base import BaseIntegral -if TYPE_CHECKING: - from ..driver.libcint import OverlapLibcint - from ..driver.pytorch import OverlapPytorch - -__all__ = ["Overlap"] - - -class Overlap(BaseIntegral): +class OverlapIntegral(BaseIntegral): """ - Overlap integral from atomic orbitals. + Base overlap class. """ - - integral: OverlapLibcint | OverlapPytorch - """Instance of actual overlap integral type.""" - - __slots__ = ["integral"] - - def __init__( - self, - driver: int = labels.INTDRIVER_LIBCINT, - device: torch.device | None = None, - dtype: torch.dtype | None = None, - **kwargs: Any, - ) -> None: - super().__init__(device=device, dtype=dtype) - - self.integral = new_overlap(driver, device=device, dtype=dtype, **kwargs) diff --git a/src/dxtb/_src/integral/types/quadrupole.py b/src/dxtb/_src/integral/types/quadrupole.py index c8df189bc..6d4e8d0f5 100644 --- a/src/dxtb/_src/integral/types/quadrupole.py +++ b/src/dxtb/_src/integral/types/quadrupole.py @@ -15,54 +15,261 @@ # See the License for the specific language governing permissions and # limitations under the License. """ +Integral Types: Quadrupole +========================== + Quadrupole integral. """ from __future__ import annotations import torch +from tad_mctc.math import einsum -from dxtb._src.constants import labels -from dxtb._src.typing import TYPE_CHECKING, Any - -from .base import BaseIntegral +from dxtb._src.typing import Tensor -if TYPE_CHECKING: - from ..driver.libcint import QuadrupoleLibcint +from ..base import BaseIntegral -__all__ = ["Quadrupole"] +__all__ = ["QuadrupoleIntegral"] -class Quadrupole(BaseIntegral): +class QuadrupoleIntegral(BaseIntegral): """ Quadrupole integral from atomic orbitals. 
""" - integral: QuadrupoleLibcint - """Instance of actual quadrupole integral type.""" - - def __init__( - self, - driver: int = labels.INTDRIVER_LIBCINT, - device: torch.device | None = None, - dtype: torch.dtype | None = None, - **kwargs: Any, - ) -> None: - super().__init__(device=device, dtype=dtype) - - # Determine which overlap class to instantiate based on the type - if driver == labels.INTDRIVER_LIBCINT: - # pylint: disable=import-outside-toplevel - from ..driver.libcint import QuadrupoleLibcint - - if kwargs.pop("force_cpu_for_libcint", True): - device = torch.device("cpu") - - self.integral = QuadrupoleLibcint(device=device, dtype=dtype, **kwargs) - elif driver in (labels.INTDRIVER_ANALYTICAL, labels.INTDRIVER_AUTOGRAD): - raise NotImplementedError( - "PyTorch versions of multipole moments are not implemented. " - "Use `libcint` as integral driver." + def traceless(self) -> Tensor: + """ + Make a quadrupole (integral) tensor traceless. + + Parameters + ---------- + qpint : Tensor + Quadrupole moment tensor of shape ``(..., 9, n, n)``. + + Returns + ------- + Tensor + Traceless Quadrupole moment tensor of shape + ``(..., 6, n, n)``. + + Raises + ------ + RuntimeError + Supplied quadrupole integral is no ``3x3`` tensor. + + Note + ---- + First the quadrupole tensor is reshaped to be symmetric. + Due to symmetry, only the lower triangular matrix is used. + + xx xy xz 0 1 2 0 + yx yy yz <=> 3 4 5 -> 3 4 + zx zy zz 6 7 8 6 7 8 + """ + + if self.matrix.ndim not in (3, 4) or self.matrix.shape[-3] != 9: + raise RuntimeError( + "Quadrupole integral must be a tensor tensor of shape " + f"'(9, nao, nao)' but is {self.matrix.shape}." ) - else: - raise ValueError(f"Unknown integral driver '{driver}'.") + + # (..., 9, norb, norb) -> (..., 3, 3, norb, norb) + shp = self.matrix.shape + qpint = self.matrix.view(*shp[:-3], 3, 3, *shp[-2:]) + + # trace: (..., 3, 3, norb, norb) -> (..., norb, norb) + tr = 0.5 * einsum("...iijk->...jk", qpint) + + self.matrix = torch.stack( + [ + 1.5 * qpint[..., 0, 0, :, :] - tr, # xx + 1.5 * qpint[..., 1, 0, :, :], # yx + 1.5 * qpint[..., 1, 1, :, :] - tr, # yy + 1.5 * qpint[..., 2, 0, :, :], # zx + 1.5 * qpint[..., 2, 1, :, :], # zy + 1.5 * qpint[..., 2, 2, :, :] - tr, # zz + ], + dim=-3, + ) + return self.matrix + + def shift_r0r0_rjrj(self, r0: Tensor, overlap: Tensor, pos: Tensor) -> Tensor: + r""" + Shift the centering of the quadrupole integral (moment operator) from + the origin (:math:`r0 = r - (0, 0, 0)`) to atoms (ket index, + :math:`rj = r - r_j`). + + Create the shift contribution for all diagonal elements of the + quadrupole integral. + + We start with the quadrupole integral generated by the ``r0`` moment + operator: + + .. math:: + + Q_{xx}^{r0} = \langle i | (r_x - r0)^2 | j \rangle = \langle i | r_x^2 | j \rangle + + Now, we shift the integral to ``r_j`` yielding the quadrupole integral + center on the respective atoms: + + .. math:: + + \begin{align} + Q_{xx} &= \langle i | (r_x - r_{xj})^2 | j \rangle \\ + &= \langle i | r_x^2 | j \rangle - 2 \langle i | r_{xj} r_x | j \rangle + \langle i | r_{xj}^2 | j \rangle \\ + &= Q_{xx}^{r0} - 2 r_{xj} \langle i | r_x | j \rangle + r_{xj}^2 \langle i | j \rangle \\ + &= Q_{xx}^{r0} - 2 r_{xj} D_{x}^{r0} + r_{xj}^2 S_{ij} + \end{align} + + Next, we create the shift contribution for all off-diagonal elements of + the quadrupole integral. + + .. 
math:: + + \begin{align} + Q_{ab} &= \langle i | (r_a - r_{aj})(r_b - r_{bj}) | j \rangle \\ + &= \langle i | r_a r_b | j \rangle - \langle i | r_a r_{bj} | j \rangle - \langle i | r_{aj} r_b | j \rangle + \langle i | r_{aj} r_{bj} | j \rangle \\ + &= Q_{ab}^{r0} - r_{bj} \langle i | r_a | j \rangle - r_{aj} \langle i | r_b | j \rangle + r_{aj} r_{bj} \langle i | j \rangle \\ + &= Q_{ab}^{r0} - r_{bj} D_a^{r0} - r_{aj} D_b^{r0} + r_{aj} r_{bj} S_{ij} + \end{align} + + Parameters + ---------- + r0 : Tensor + Origin-centered dipole integral. + overlap : Tensor + Monopole integral (overlap). + pos : Tensor + Orbital-resolved atomic positions. + + Raises + ------ + RuntimeError + Shape mismatch between ``positions`` and ``overlap``. + The positions must be orbital-resolved. + + Returns + ------- + Tensor + Second-index (ket) atom-centered quadrupole integral. + """ + if pos.shape[-2] != overlap.shape[-1]: + raise RuntimeError( + "Shape mismatch between positions and overlap integral. " + "The position tensor must be spread to orbital-resolution." + ) + + # cartesian components for convenience + x = pos[..., 0] + y = pos[..., 1] + z = pos[..., 2] + dpx = r0[..., 0, :, :] + dpy = r0[..., 1, :, :] + dpz = r0[..., 2, :, :] + + # construct shift contribution from dipole and monopole (overlap) moments + shift_xx = shift_diagonal(x, dpx, overlap) + shift_yy = shift_diagonal(y, dpy, overlap) + shift_zz = shift_diagonal(z, dpz, overlap) + shift_yx = shift_offdiag(y, x, dpy, dpx, overlap) + shift_zx = shift_offdiag(z, x, dpz, dpx, overlap) + shift_zy = shift_offdiag(z, y, dpz, dpy, overlap) + + # collect the trace of shift contribution + tr = 0.5 * (shift_xx + shift_yy + shift_zz) + + self.matrix = torch.stack( + [ + self.matrix[..., 0, :, :] + 1.5 * shift_xx - tr, # xx + self.matrix[..., 1, :, :] + 1.5 * shift_yx, # yx + self.matrix[..., 2, :, :] + 1.5 * shift_yy - tr, # yy + self.matrix[..., 3, :, :] + 1.5 * shift_zx, # zx + self.matrix[..., 4, :, :] + 1.5 * shift_zy, # zy + self.matrix[..., 5, :, :] + 1.5 * shift_zz - tr, # zz + ], + dim=-3, + ) + return self.matrix + + +def shift_diagonal(c: Tensor, dpc: Tensor, s: Tensor) -> Tensor: + r""" + Create the shift contribution for all diagonal elements of the quadrupole + integral. + + We start with the quadrupole integral generated by the ``r0`` moment + operator: + + .. math:: + + Q_{xx}^{r0} = \langle i | (r_x - r0)^2 | j \rangle = \langle i | r_x^2 | j \rangle + + Now, we shift the integral to ``r_j`` yielding the quadrupole integral + center on the respective atoms: + + .. math:: + + \begin{align} + Q_{xx} &= \langle i | (r_x - r_{xj})^2 | j \rangle \\ + &= \langle i | r_x^2 | j \rangle - 2 \langle i | r_{xj} r_x | j \rangle + \langle i | r_{xj}^2 | j \rangle \\ + &= Q_{xx}^{r0} - 2 r_{xj} \langle i | r_x | j \rangle + r_{xj}^2 \langle i | j \rangle \\ + &= Q_{xx}^{r0} - 2 r_{xj} D_{x}^{r0} + r_{xj}^2 S_{ij} + \end{align} + + Parameters + ---------- + c : Tensor + Cartesian component. + dpc : Tensor + Cartesian component of dipole integral (`r0` operator). + s : Tensor + Overlap integral. + + Returns + ------- + Tensor + Shift contribution for diagonals of quadrupole integral. + """ + shift_1 = -2 * einsum("...j,...ij->...ij", c, dpc) + shift_2 = einsum("...j,...j,...ij->...ij", c, c, s) + return shift_1 + shift_2 + + +def shift_offdiag(a: Tensor, b: Tensor, dpa: Tensor, dpb: Tensor, s: Tensor) -> Tensor: + r""" + Create the shift contribution for all off-diagonal elements of the + quadrupole integral. + + .. 
math:: + + \begin{align} + Q_{ab} &= \langle i | (r_a - r_{aj})(r_b - r_{bj}) | j \rangle \\ + &= \langle i | r_a r_b | j \rangle - \langle i | r_a r_{bj} | j \rangle - \langle i | r_{aj} r_b | j \rangle + \langle i | r_{aj} r_{bj} | j \rangle \\ + &= Q_{ab}^{r0} - r_{bj} \langle i | r_a | j \rangle - r_{aj} \langle i | r_b | j \rangle + r_{aj} r_{bj} \langle i | j \rangle \\ + &= Q_{ab}^{r0} - r_{bj} D_a^{r0} - r_{aj} D_b^{r0} + r_{aj} r_{bj} S_{ij} + \end{align} + + Parameters + ---------- + a : Tensor + First cartesian component. + b : Tensor + Second cartesian component. + dpa : Tensor + First cartesian component of dipole integral (r0 operator). + dpb : Tensor + Second cartesian component of dipole integral (r0 operator). + s : Tensor + Overlap integral. + + Returns + ------- + Tensor + Shift contribution of off-diagonal elements of quadrupole integral. + """ + shift_ab_1 = -einsum("...j,...ij->...ij", b, dpa) + shift_ab_2 = -einsum("...j,...ij->...ij", a, dpb) + shift_ab_3 = einsum("...j,...j,...ij->...ij", a, b, s) + + return shift_ab_1 + shift_ab_2 + shift_ab_3 diff --git a/src/dxtb/_src/integral/utils.py b/src/dxtb/_src/integral/utils.py new file mode 100644 index 000000000..db00c899f --- /dev/null +++ b/src/dxtb/_src/integral/utils.py @@ -0,0 +1,34 @@ +# This file is part of dxtb. +# +# SPDX-Identifier: Apache-2.0 +# Copyright (C) 2024 Grimme Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Integrals: Utility Functions +============================ + +Integral-related utility functions. +""" + +from __future__ import annotations + +import torch + +from dxtb._src.typing import Tensor + + +def snorm(ovlp: Tensor) -> Tensor: + d = ovlp.diagonal(dim1=-1, dim2=-2) + zero = torch.tensor(0.0, dtype=d.dtype, device=d.device) + return torch.where(d == 0.0, zero, torch.pow(d, -0.5)) diff --git a/src/dxtb/_src/integral/wrappers.py b/src/dxtb/_src/integral/wrappers.py index fc72bbbea..6b97b742f 100644 --- a/src/dxtb/_src/integral/wrappers.py +++ b/src/dxtb/_src/integral/wrappers.py @@ -72,8 +72,8 @@ from dxtb._src.xtb.gfn1 import GFN1Hamiltonian from dxtb._src.xtb.gfn2 import GFN2Hamiltonian -from .factory import new_driver -from .types import Dipole, Overlap, Quadrupole +from .driver.manager import DriverManager +from .factory import new_dipint, new_overlap, new_quadint __all__ = ["hcore", "overlap", "dipint", "quadint"] @@ -98,20 +98,20 @@ def hcore(numbers: Tensor, positions: Tensor, par: Param, **kwargs: Any) -> Tens Raises ------ - TypeError + ValueError If the Hamiltonian parametrization does not contain meta data or if the meta data does not contain a name to select the correct Hamiltonian. ValueError If the Hamiltonian name is unknown. """ if par.meta is None: - raise TypeError( + raise ValueError( "Meta data of Hamiltonian parametrization must contain a name. " "Otherwise, the correct Hamiltonian cannot be selected internally." 
) if par.meta.name is None: - raise TypeError( + raise ValueError( "The name field of the meta data of the Hamiltonian " "parametrization must contain a name. Otherwise, the correct " "Hamiltonian cannot be selected internally." @@ -122,22 +122,14 @@ def hcore(numbers: Tensor, positions: Tensor, par: Param, **kwargs: Any) -> Tens name = par.meta.name.casefold() if name == "gfn1-xtb": - h0 = GFN1Hamiltonian(numbers, par, ihelp, **dd) + h0 = GFN1Hamiltonian(numbers, par, ihelp, **dd, **kwargs) elif name == "gfn2-xtb": - h0 = GFN2Hamiltonian(numbers, par, ihelp, **dd) + h0 = GFN2Hamiltonian(numbers, par, ihelp, **dd, **kwargs) else: raise ValueError(f"Unknown Hamiltonian type '{name}'.") - # TODOGFN2: Handle possibly different CNs - cn = kwargs.pop("cn", None) - if cn is None: - # pylint: disable=import-outside-toplevel - from ..ncoord import cn_d3 - - cn = cn_d3(numbers, positions) - ovlp = overlap(numbers, positions, par) - return h0.build(positions, ovlp, cn=cn) + return h0.build(positions, ovlp.to(h0.device)) def overlap(numbers: Tensor, positions: Tensor, par: Param, **kwargs: Any) -> Tensor: @@ -225,43 +217,70 @@ def _integral( Returns ------- Tensor - Integral matrix of shape ``(nb, nao, nao)`` for overlap and - ``(nb, 3, nao, nao)`` for dipole and quadrupole. + Integral matrix of shape ``(..., nao, nao)`` for overlap and + ``(..., 3, nao, nao)`` for dipole and quadrupole. Raises ------ ValueError If the integral type is unknown. """ + + if integral_type not in ("_overlap", "_dipole", "_quadrupole"): + raise ValueError(f"Unknown integral type '{integral_type}'.") + dd: DD = {"device": positions.device, "dtype": positions.dtype} + ihelp = IndexHelper.from_numbers(numbers, par) + + normalize = kwargs.pop("normalize", True) + + ########## + # Driver # + ########## # Determine which driver class to instantiate (defaults to libcint) driver_name = kwargs.pop("driver", labels.INTDRIVER_LIBCINT) - driver = new_driver(driver_name, numbers, par, **dd) # setup driver for integral calculation - driver.setup(positions) + drv_mgr = DriverManager(driver_name, **dd) + drv_mgr.create_driver(numbers, par, ihelp) + drv_mgr.driver.setup(positions) + + ########### + # Overlap # + ########### - # inject driver into requested integral if integral_type == "_overlap": - integral = Overlap(driver=driver_name, **dd, **kwargs) - elif integral_type in ("_dipole", "_quadrupole"): - ovlp = Overlap(driver=driver_name, **dd, **kwargs) - - # multipole integrals require the overlap for normalization - if ovlp.integral._matrix is None or ovlp.integral._norm is None: - ovlp.build(driver) - - if integral_type == "_dipole": - integral = Dipole(driver=driver_name, **dd, **kwargs) - elif integral_type == "_quadrupole": - integral = Quadrupole(driver=driver_name, **dd, **kwargs) - else: - raise ValueError(f"Unknown integral type '{integral_type}'.") - - integral.integral.norm = ovlp.integral.norm + integral = new_overlap(drv_mgr.driver_type, **dd, **kwargs) + + # actual integral calculation + integral.build(drv_mgr.driver) + + if normalize is True: + integral.normalize(integral.norm) + + return integral.matrix + + ############# + # Multipole # + ############# + + # multipole integrals require the overlap for normalization + ovlp = new_overlap(drv_mgr.driver_type, **dd, **kwargs) + if ovlp._matrix is None or ovlp.norm is None: + ovlp.build(drv_mgr.driver) + + if integral_type == "_dipole": + integral = new_dipint(driver=drv_mgr.driver_type, **dd, **kwargs) + elif integral_type == "_quadrupole": + integral = 
new_quadint(driver=drv_mgr.driver_type, **dd, **kwargs)
     else:
         raise ValueError(f"Unknown integral type '{integral_type}'.")
 
     # actual integral calculation
-    return integral.build(driver)
+    integral.build(drv_mgr.driver)
+
+    if normalize is True:
+        integral.normalize(ovlp.norm)
+
+    return integral.matrix
diff --git a/src/dxtb/_src/typing/project.py b/src/dxtb/_src/typing/project.py
index 26b82de54..2dd5e1797 100644
--- a/src/dxtb/_src/typing/project.py
+++ b/src/dxtb/_src/typing/project.py
@@ -23,11 +23,12 @@
 from __future__ import annotations
 
 import torch
+from tad_mctc.ncoord.typing import CNFunction, CNGradFunction
 
 from .builtin import TypedDict
 from .compat import Slicer
 
-__all__ = ["ContainerData", "Slicers"]
+__all__ = ["CNFunction", "CNGradFunction", "ContainerData", "Slicers"]
 
 
 class Slicers(TypedDict):
diff --git a/src/dxtb/_src/utils/misc.py b/src/dxtb/_src/utils/misc.py
index f69c86028..9bf7ebdb1 100644
--- a/src/dxtb/_src/utils/misc.py
+++ b/src/dxtb/_src/utils/misc.py
@@ -58,7 +58,7 @@ def get_all_slots(cls):
         if p.__name__ not in ("object", cls.__class__.__name__)
     ]
 
-    # and the hidden slots "__" and the "unit" slots
+    # and the hidden slots "__" and the "unit" slots
     parents_slots: list[str] = [
         s for p in parents for s in p.__slots__ if "__" not in s
     ]
diff --git a/src/dxtb/_src/xtb/abc.py b/src/dxtb/_src/xtb/abc.py
new file mode 100644
index 000000000..d261d75e6
--- /dev/null
+++ b/src/dxtb/_src/xtb/abc.py
@@ -0,0 +1,94 @@
+# This file is part of dxtb.
+#
+# SPDX-Identifier: Apache-2.0
+# Copyright (C) 2024 Grimme Group
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+xTB Hamiltonians: ABC
+=====================
+
+Abstract base class for xTB Hamiltonians.
+"""
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+
+from dxtb._src.components.interactions import Potential
+from dxtb._src.typing import Tensor
+
+__all__ = ["HamiltonianABC"]
+
+
+class HamiltonianABC(ABC):
+    """
+    Abstract base class for Hamiltonians.
+    """
+
+    @abstractmethod
+    def build(self, positions: Tensor, overlap: Tensor | None = None) -> Tensor:
+        """
+        Build the xTB Hamiltonian.
+
+        Parameters
+        ----------
+        positions : Tensor
+            Atomic positions of molecular structure.
+        overlap : Tensor | None, optional
+            Overlap matrix. If ``None``, the true xTB Hamiltonian is *not*
+            built. Defaults to ``None``.
+
+        Returns
+        -------
+        Tensor
+            Hamiltonian (always symmetric).
+        """
+
+    @abstractmethod
+    def get_gradient(
+        self,
+        positions: Tensor,
+        overlap: Tensor,
+        doverlap: Tensor,
+        pmat: Tensor,
+        wmat: Tensor,
+        pot: Potential,
+        cn: Tensor,
+    ) -> tuple[Tensor, Tensor]:
+        """
+        Calculate gradient of the full Hamiltonian with respect to atomic positions.
+
+        Parameters
+        ----------
+        positions : Tensor
+            Atomic positions of molecular structure.
+        overlap : Tensor
+            Overlap matrix.
+        doverlap : Tensor
+            Derivative of the overlap matrix.
+        pmat : Tensor
+            Density matrix.
+        wmat : Tensor
+            Energy-weighted density.
+        pot : Potential
+            Self-consistent electrostatic potential.
+ cn : Tensor + Coordination number. + + Returns + ------- + tuple[Tensor, Tensor] + Derivative of energy with respect to coordination number (first + tensor) and atomic positions (second tensor). + """ diff --git a/src/dxtb/_src/xtb/base.py b/src/dxtb/_src/xtb/base.py index 857178a7a..c982c4167 100644 --- a/src/dxtb/_src/xtb/base.py +++ b/src/dxtb/_src/xtb/base.py @@ -23,82 +23,15 @@ from __future__ import annotations -from abc import ABC, abstractmethod - import torch from dxtb import IndexHelper -from dxtb._src.components.interactions import Potential from dxtb._src.param import Param -from dxtb._src.typing import Tensor, TensorLike - -__all__ = ["HamiltonianABC", "BaseHamiltonian"] - - -class HamiltonianABC(ABC): - """ - Abstract base class for Hamiltonians. - """ - - @abstractmethod - def build( - self, positions: Tensor, overlap: Tensor, cn: Tensor | None = None - ) -> Tensor: - """ - Build the xTB Hamiltonian. +from dxtb._src.typing import CNFunction, PathLike, Tensor, TensorLike - Parameters - ---------- - positions : Tensor - Atomic positions of molecular structure. - overlap : Tensor - Overlap matrix. - cn : Tensor | None, optional - Coordination number. Defaults to ``None``. - - Returns - ------- - Tensor - Hamiltonian (always symmetric). - """ - - @abstractmethod - def get_gradient( - self, - positions: Tensor, - overlap: Tensor, - doverlap: Tensor, - pmat: Tensor, - wmat: Tensor, - pot: Potential, - cn: Tensor, - ) -> tuple[Tensor, Tensor]: - """ - Calculate gradient of the full Hamiltonian with respect ot atomic positions. +from .abc import HamiltonianABC - Parameters - ---------- - positions : Tensor - Atomic positions of molecular structure. - overlap : Tensor - Overlap matrix. - doverlap : Tensor - Derivative of the overlap matrix. - pmat : Tensor - Density matrix. - wmat : Tensor - Energy-weighted density. - pot : Tensor - Self-consistent electrostatic potential. - cn : Tensor - Coordination number. - - Returns - ------- - tuple[Tensor, Tensor] - Derivative of energy with respect to coordination number (first - tensor) and atomic positions (second tensor). - """ +__all__ = ["BaseHamiltonian"] class BaseHamiltonian(HamiltonianABC, TensorLike): @@ -142,6 +75,9 @@ class BaseHamiltonian(HamiltonianABC, TensorLike): rad: Tensor """Van-der-Waals radius of each species.""" + cn: None | CNFunction + """Coordination number function.""" + __slots__ = [ "numbers", "unique", @@ -189,6 +125,19 @@ def matrix(self) -> Tensor | None: def matrix(self, mat: Tensor) -> None: self._matrix = mat + def clear(self) -> None: + """ + Clear the integral matrix. + """ + self._matrix = None + + @property + def requires_grad(self) -> bool: + if self._matrix is None: + return False + + return self._matrix.requires_grad + def get_occupation(self) -> Tensor: """ Obtain the reference occupation numbers for each orbital. @@ -204,3 +153,18 @@ def get_occupation(self) -> Tensor: refocc / orb_per_shell, torch.tensor(0, **self.dd), ) + + def to_pt(self, path: PathLike | None = None) -> None: + """ + Save the integral matrix to a file. + + Parameters + ---------- + path : PathLike | None + Path to the file where the integral matrix should be saved. If + ``None``, the matrix is saved to the default location. 
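+            The default is ``f"{self.label.casefold()}.pt"``, i.e., the
+            lowercase label of the integral in the current working directory.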
+ """ + if path is None: + path = f"{self.label.casefold()}.pt" + + torch.save(self.matrix, path) diff --git a/src/dxtb/_src/xtb/gfn1.py b/src/dxtb/_src/xtb/gfn1.py index d1485909b..5cbf95f17 100644 --- a/src/dxtb/_src/xtb/gfn1.py +++ b/src/dxtb/_src/xtb/gfn1.py @@ -28,12 +28,13 @@ from tad_mctc.batch import real_pairs from tad_mctc.convert import symmetrize from tad_mctc.data.radii import ATOMIC as ATOMIC_RADII +from tad_mctc.ncoord import cn_d3 from tad_mctc.units import EV2AU from dxtb import IndexHelper from dxtb._src.components.interactions import Potential from dxtb._src.param import Param -from dxtb._src.typing import Tensor +from dxtb._src.typing import Any, Tensor from .base import BaseHamiltonian @@ -53,7 +54,7 @@ def __init__( ihelp: IndexHelper, device: torch.device | None = None, dtype: torch.dtype | None = None, - **_, + **kwargs: Any, ) -> None: super().__init__(numbers, par, ihelp, device, dtype) @@ -79,6 +80,9 @@ def __init__( self.selfenergy = self.selfenergy * EV2AU self.kcn = self.kcn * EV2AU + # coordination number function + self.cn = kwargs.pop("cn", cn_d3) + # dtype should always be correct as it always uses self.dtype if any( tensor.dtype != self.dtype @@ -257,9 +261,7 @@ def _get_hscale(self) -> Tensor: return ksh - def build( - self, positions: Tensor, overlap: Tensor, cn: Tensor | None = None - ) -> Tensor: + def build(self, positions: Tensor, overlap: Tensor | None = None) -> Tensor: """ Build the xTB Hamiltonian. @@ -267,10 +269,9 @@ def build( ---------- positions : Tensor Atomic positions of molecular structure. - overlap : Tensor - Overlap matrix. - cn : Tensor | None, optional - Coordination number. Defaults to ``None``. + overlap : Tensor | None, optional + Overlap matrix. If ``None``, the true xTB Hamiltonian is *not* + built. Defaults to ``None``. Returns ------- @@ -294,8 +295,10 @@ def build( # ---------------- # Eq.29: H_(mu,mu) # ---------------- - if cn is None: + if self.cn is None: cn = torch.zeros_like(self.numbers, **self.dd) + else: + cn = self.cn(self.numbers, positions) kcn = self.ihelp.spread_ushell_to_shell(self.kcn) @@ -363,10 +366,12 @@ def build( ), dim=(-2, -1), ) - h = hcore * overlap + + if overlap is not None: + hcore = hcore * overlap # force symmetry to avoid problems through numerical errors - h0 = symmetrize(h) + h0 = symmetrize(hcore) self.matrix = h0 return h0 diff --git a/src/dxtb/_src/xtb/gfn2.py b/src/dxtb/_src/xtb/gfn2.py index 6cd4d7637..147099f01 100644 --- a/src/dxtb/_src/xtb/gfn2.py +++ b/src/dxtb/_src/xtb/gfn2.py @@ -36,9 +36,7 @@ class GFN2Hamiltonian(BaseHamiltonian): The GFN2-xTB Hamiltonian. 
""" - def build( - self, positions: Tensor, overlap: Tensor, cn: Tensor | None = None - ) -> Tensor: + def build(self, positions: Tensor, overlap: Tensor | None = None) -> Tensor: raise NotImplementedError("GFN2 not implemented yet.") def get_gradient( diff --git a/src/dxtb/calculators.py b/src/dxtb/calculators.py index 458369c29..280907c03 100644 --- a/src/dxtb/calculators.py +++ b/src/dxtb/calculators.py @@ -119,3 +119,12 @@ from dxtb._src.calculators.types import AutogradCalculator as AutogradCalculator from dxtb._src.calculators.types import EnergyCalculator as EnergyCalculator from dxtb._src.calculators.types import NumericalCalculator as NumericalCalculator + +__all__ = [ + "GFN1Calculator", + "GFN2Calculator", + "AnalyticalCalculator", + "AutogradCalculator", + "EnergyCalculator", + "NumericalCalculator", +] diff --git a/src/dxtb/components/base.py b/src/dxtb/components/base.py index 84da7232d..ee9a2a271 100644 --- a/src/dxtb/components/base.py +++ b/src/dxtb/components/base.py @@ -42,6 +42,7 @@ "ClassicalCache", # "Interaction", + "InteractionCache", "InteractionList", "InteractionListCache", ] diff --git a/src/dxtb/components/coulomb.py b/src/dxtb/components/coulomb.py index 4d6a85741..ab14a48bd 100644 --- a/src/dxtb/components/coulomb.py +++ b/src/dxtb/components/coulomb.py @@ -26,9 +26,4 @@ from dxtb._src.components.interactions.coulomb import new_es2 as new_es2 from dxtb._src.components.interactions.coulomb import new_es3 as new_es3 -__all__ = [ - "ES2", - "ES3", - "new_es2", - "new_es3", -] +__all__ = ["ES2", "ES3", "new_es2", "new_es3"] diff --git a/src/dxtb/components/dispersion.py b/src/dxtb/components/dispersion.py index ddc9d7162..fe70540a4 100644 --- a/src/dxtb/components/dispersion.py +++ b/src/dxtb/components/dispersion.py @@ -25,8 +25,4 @@ from dxtb._src.components.classicals.dispersion import DispersionD4 as DispersionD4 from dxtb._src.components.classicals.dispersion import new_dispersion as new_dispersion -__all__ = [ - "DispersionD3", - "DispersionD4", - "new_dispersion", -] +__all__ = ["DispersionD3", "DispersionD4", "new_dispersion"] diff --git a/src/dxtb/components/halogen.py b/src/dxtb/components/halogen.py index 29743913c..8dda02ea0 100644 --- a/src/dxtb/components/halogen.py +++ b/src/dxtb/components/halogen.py @@ -24,7 +24,4 @@ from dxtb._src.components.classicals.halogen import Halogen as Halogen from dxtb._src.components.classicals.halogen import new_halogen as new_halogen -__all__ = [ - "Halogen", - "new_halogen", -] +__all__ = ["Halogen", "new_halogen"] diff --git a/src/dxtb/components/repulsion.py b/src/dxtb/components/repulsion.py index 365cdd564..8b62ebd58 100644 --- a/src/dxtb/components/repulsion.py +++ b/src/dxtb/components/repulsion.py @@ -24,7 +24,4 @@ from dxtb._src.components.classicals.repulsion import Repulsion as Repulsion from dxtb._src.components.classicals.repulsion import new_repulsion as new_repulsion -__all__ = [ - "Repulsion", - "new_repulsion", -] +__all__ = ["Repulsion", "new_repulsion"] diff --git a/src/dxtb/components/solvation.py b/src/dxtb/components/solvation.py index a8de33310..d85113980 100644 --- a/src/dxtb/components/solvation.py +++ b/src/dxtb/components/solvation.py @@ -26,7 +26,4 @@ ) from dxtb._src.components.interactions.solvation import new_solvation as new_solvation -__all__ = [ - "GeneralizedBorn", - "new_solvation", -] +__all__ = ["GeneralizedBorn", "new_solvation"] diff --git a/src/dxtb/config.py b/src/dxtb/config.py index 8a3118fd8..d88240671 100644 --- a/src/dxtb/config.py +++ b/src/dxtb/config.py @@ -27,3 
+27,12 @@ from dxtb._src.calculators.config.main import Config as Config from dxtb._src.calculators.config.scf import ConfigFermi as ConfigFermi from dxtb._src.calculators.config.scf import ConfigSCF as ConfigSCF + +__all__ = [ + "Config", + "ConfigCache", + "ConfigCacheStore", + "ConfigIntegrals", + "ConfigFermi", + "ConfigSCF", +] diff --git a/src/dxtb/integrals/__init__.py b/src/dxtb/integrals/__init__.py index 2bb712878..2d430ed12 100644 --- a/src/dxtb/integrals/__init__.py +++ b/src/dxtb/integrals/__init__.py @@ -52,6 +52,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: + from dxtb.integrals import factories as factories from dxtb.integrals import types as types from dxtb.integrals import wrappers as wrappers else: @@ -59,10 +60,11 @@ __getattr__, __dir__, __all__ = _lazy.attach_module( __name__, - ["types", "wrappers"], + ["factories", "types", "wrappers"], ) del _lazy del TYPE_CHECKING from dxtb._src.integral.container import Integrals as Integrals +from dxtb._src.integral.driver import DriverManager as DriverManager diff --git a/src/dxtb/integrals/factories.py b/src/dxtb/integrals/factories.py new file mode 100644 index 000000000..c289e5cea --- /dev/null +++ b/src/dxtb/integrals/factories.py @@ -0,0 +1,29 @@ +# This file is part of dxtb. +# +# SPDX-Identifier: Apache-2.0 +# Copyright (C) 2024 Grimme Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Integrals: Factory Functions +============================ + +Factory functions for integral classes. +""" + +from dxtb._src.integral.factory import new_dipint as new_dipint +from dxtb._src.integral.factory import new_hcore as new_hcore +from dxtb._src.integral.factory import new_overlap as new_overlap +from dxtb._src.integral.factory import new_quadint as new_quadint + +__all__ = ["new_dipint", "new_hcore", "new_overlap", "new_quadint"] diff --git a/src/dxtb/integrals/types.py b/src/dxtb/integrals/types.py index 9505653ed..875dd2049 100644 --- a/src/dxtb/integrals/types.py +++ b/src/dxtb/integrals/types.py @@ -21,7 +21,8 @@ Integral types for the calculation of molecular integrals. 
""" -from dxtb._src.integral.types import Dipole as Dipole -from dxtb._src.integral.types import HCore as HCore -from dxtb._src.integral.types import Overlap as Overlap -from dxtb._src.integral.types import Quadrupole as Quadrupole +from dxtb._src.integral.types import DipoleIntegral as DipoleIntegral +from dxtb._src.integral.types import OverlapIntegral as OverlapIntegral +from dxtb._src.integral.types import QuadrupoleIntegral as QuadrupoleIntegral + +__all__ = ["DipoleIntegral", "OverlapIntegral", "QuadrupoleIntegral"] diff --git a/src/dxtb/typing.py b/src/dxtb/typing.py index 4ab1567c9..610596ba0 100644 --- a/src/dxtb/typing.py +++ b/src/dxtb/typing.py @@ -26,3 +26,11 @@ from dxtb._src.typing import Slicers as Slicers from dxtb._src.typing import Tensor as Tensor from dxtb._src.typing import TensorLike as TensorLike + +__all__ = [ + "DD", + "Slicer", + "Slicers", + "Tensor", + "TensorLike", +] diff --git a/test/conftest.py b/test/conftest.py index 8003132aa..3ea83dea7 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -38,6 +38,9 @@ DEVICE: torch.device | None = None """Name of Device.""" +NONDET_TOL = 1e-7 +"""Tolerance for non-deterministic tests.""" + # A bug in PyTorch 2.3.0 and 2.3.1 somehow requires manual import of # `torch._dynamo` to avoid errors with functorch in custom backward diff --git a/test/test_a_memory_leak/test_higher_deriv.py b/test/test_a_memory_leak/test_higher_deriv.py index 732ddeecf..257c47bb4 100644 --- a/test/test_a_memory_leak/test_higher_deriv.py +++ b/test/test_a_memory_leak/test_higher_deriv.py @@ -22,8 +22,6 @@ from __future__ import annotations -import gc - import pytest import torch from tad_mctc.data.molecules import mols as samples @@ -54,18 +52,18 @@ def fcn(): ihelp = IndexHelper.from_numbers(numbers, par) # variables to be differentiated - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) rep = new_repulsion(numbers, par, **dd) assert rep is not None cache = rep.get_cache(numbers, ihelp) - energy = rep.get_energy(positions, cache).sum() + energy = rep.get_energy(pos, cache).sum() - _ = nth_derivative(energy, positions, n) + _ = nth_derivative(energy, pos, n) del numbers - del positions + del pos del ihelp del rep del cache diff --git a/test/test_a_memory_leak/test_repulsion.py b/test/test_a_memory_leak/test_repulsion.py index 69858e40b..1e3d8698e 100644 --- a/test/test_a_memory_leak/test_repulsion.py +++ b/test/test_a_memory_leak/test_repulsion.py @@ -22,8 +22,6 @@ from __future__ import annotations -import gc - import pytest import torch from tad_mctc.data.molecules import mols as samples @@ -75,21 +73,19 @@ def fcn(): **dd, requires_grad=True, ) - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) rep = Repulsion(arep, zeff, kexp, **dd) cache = rep.get_cache(numbers, ihelp) - energy = rep.get_energy(positions, cache).sum() - _ = torch.autograd.grad( - energy, (positions, arep, zeff, kexp), create_graph=True - ) + energy = rep.get_energy(pos, cache).sum() + _ = torch.autograd.grad(energy, (pos, arep, zeff, kexp), create_graph=True) # known reference cycle for create_graph=True energy.backward() del numbers - del positions + del pos del ihelp del rep del cache diff --git a/test/test_a_memory_leak/test_scf.py b/test/test_a_memory_leak/test_scf.py index 5eef810d6..1a6db572f 100644 --- a/test/test_a_memory_leak/test_scf.py +++ b/test/test_a_memory_leak/test_scf.py @@ -22,8 +22,6 @@ from __future__ import annotations -import gc - import pytest import torch from tad_mctc.data.molecules import 
mols as samples @@ -58,19 +56,19 @@ def fcn(): calc = Calculator(numbers, par, opts=options, **dd) # variables to be differentiated - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) - result = calc.singlepoint(positions, charges) + result = calc.singlepoint(pos, charges) energy = result.scf.sum(-1) - _ = torch.autograd.grad(energy, (positions), create_graph=create_graph) + _ = torch.autograd.grad(energy, (pos), create_graph=create_graph) # known reference cycle for create_graph=True if create_graph is True: energy.backward() del numbers - del positions + del pos del charges del calc del result @@ -101,11 +99,11 @@ def fcn(): calc = Calculator(numbers, par, opts=options, **dd) # variables to be differentiated - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) - energy = calc.energy(positions, charges) + energy = calc.energy(pos, charges) - _ = torch.autograd.grad(energy, (positions), create_graph=create_graph) + _ = torch.autograd.grad(energy, (pos), create_graph=create_graph) # known reference cycle for create_graph=True if create_graph is True: @@ -140,19 +138,19 @@ def fcn(): calc = Calculator(numbers, par, opts=options, **dd) # variables to be differentiated - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) - result = calc.singlepoint(positions, charges) + result = calc.singlepoint(pos, charges) energy = result.scf.sum(-1) - _ = torch.autograd.grad(energy, (positions), create_graph=create_graph) + _ = torch.autograd.grad(energy, (pos), create_graph=create_graph) # known reference cycle for create_graph=True if create_graph is True: energy.backward() del numbers - del positions + del pos del charges del calc del result diff --git a/test/test_calculator/test_cache/test_integrals.py b/test/test_calculator/test_cache/test_integrals.py new file mode 100644 index 000000000..63d0aafbe --- /dev/null +++ b/test/test_calculator/test_cache/test_integrals.py @@ -0,0 +1,88 @@ +# This file is part of dxtb. +# +# SPDX-Identifier: Apache-2.0 +# Copyright (C) 2024 Grimme Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Test caching integrals. 
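+
+The overlap integral is rebuilt for every energy evaluation and should be
+discarded afterwards, unless the positions require gradients and the
+tensors are thus still needed in the autograd graph.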
+""" + +from __future__ import annotations + +import pytest +import torch + +from dxtb._src.typing import DD, Tensor +from dxtb.calculators import GFN1Calculator + +from ...conftest import DEVICE + +opts = {"cache_enabled": True, "verbosity": 0} + + +@pytest.mark.parametrize("dtype", [torch.float, torch.double]) +def test_overlap_deleted(dtype: torch.dtype) -> None: + dd: DD = {"device": DEVICE, "dtype": dtype} + + numbers = torch.tensor([3, 1], device=DEVICE) + positions = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]], **dd) + + calc = GFN1Calculator(numbers, opts={"verbosity": 0}, **dd) + assert calc._ncalcs == 0 + + # overlap should not be cached + assert calc.opts.cache.store.overlap == False + + # one successful calculation + energy = calc.get_energy(positions) + assert calc._ncalcs == 1 + assert isinstance(energy, Tensor) + + # cache should be empty + assert calc.cache.overlap is None + + # ... but also the tensors in the calculator should be deleted + assert calc.integrals.overlap is not None + assert calc.integrals.overlap._matrix is None + assert calc.integrals.overlap._norm is None + assert calc.integrals.overlap._gradient is None + + +@pytest.mark.parametrize("dtype", [torch.float, torch.double]) +def test_overlap_retained_for_grad(dtype: torch.dtype) -> None: + dd: DD = {"device": DEVICE, "dtype": dtype} + + numbers = torch.tensor([3, 1], device=DEVICE) + positions = torch.tensor( + [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]], **dd, requires_grad=True + ) + + calc = GFN1Calculator(numbers, opts={"verbosity": 0}, **dd) + assert calc._ncalcs == 0 + + # overlap should not be cached + assert calc.opts.cache.store.overlap == False + + # one successful calculation + energy = calc.get_energy(positions) + assert calc._ncalcs == 1 + assert isinstance(energy, Tensor) + + # cache should still be empty ... + assert calc.cache.overlap is None + + # ... 
but the tensors in the calculator should still be there + assert calc.integrals.overlap is not None + assert calc.integrals.overlap._matrix is not None + assert calc.integrals.overlap._norm is not None diff --git a/test/test_calculator/test_cache/test_properties.py b/test/test_calculator/test_cache/test_properties.py index d95dace90..03242de78 100644 --- a/test/test_calculator/test_cache/test_properties.py +++ b/test/test_calculator/test_cache/test_properties.py @@ -64,23 +64,23 @@ def test_forces(dtype: torch.dtype, grad_mode: Literal["functorch", "row"]) -> N numbers = torch.tensor([3, 1], device=DEVICE) positions = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]], **dd) - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) options = dict(opts, **{"scf_mode": "full", "mixer": "anderson"}) calc = AutogradCalculator(numbers, GFN1_XTB, opts=options, **dd) assert calc._ncalcs == 0 - prop = calc.get_forces(positions, grad_mode=grad_mode) + prop = calc.get_forces(pos, grad_mode=grad_mode) assert calc._ncalcs == 1 assert isinstance(prop, Tensor) # cache is used for same calc - prop = calc.get_forces(positions, grad_mode=grad_mode) + prop = calc.get_forces(pos, grad_mode=grad_mode) assert calc._ncalcs == 1 assert isinstance(prop, Tensor) # cache is used for energy - prop = calc.get_energy(positions) + prop = calc.get_energy(pos) assert calc._ncalcs == 1 assert isinstance(prop, Tensor) @@ -95,22 +95,22 @@ def test_forces_analytical(dtype: torch.dtype) -> None: numbers = torch.tensor([3, 1], device=DEVICE) positions = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]], **dd) - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) calc = AnalyticalCalculator(numbers, GFN1_XTB, opts=opts, **dd) assert calc._ncalcs == 0 - prop = calc.get_forces(positions) + prop = calc.get_forces(pos) assert calc._ncalcs == 1 assert isinstance(prop, Tensor) # cache is used for same calc - prop = calc.get_forces(positions) + prop = calc.get_forces(pos) assert calc._ncalcs == 1 assert isinstance(prop, Tensor) # cache is used for energy - prop = calc.get_energy(positions) + prop = calc.get_energy(pos) assert calc._ncalcs == 1 assert isinstance(prop, Tensor) @@ -126,31 +126,31 @@ def test_hessian(dtype: torch.dtype, use_functorch: bool) -> None: numbers = torch.tensor([3, 1], device=DEVICE) positions = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]], **dd) - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) options = dict(opts, **{"scf_mode": "full", "mixer": "anderson"}) calc = AutogradCalculator(numbers, GFN1_XTB, opts=options, **dd) assert calc._ncalcs == 0 - prop = calc.get_hessian(positions, use_functorch=use_functorch) + prop = calc.get_hessian(pos, use_functorch=use_functorch) assert calc._ncalcs == 1 assert isinstance(prop, Tensor) # cache is used for same calc assert "hessian" in calc.cache.list_cached_properties() - prop = calc.get_hessian(positions, use_functorch=use_functorch) + prop = calc.get_hessian(pos, use_functorch=use_functorch) assert calc._ncalcs == 1 assert isinstance(prop, Tensor) # cache is used for energy assert "energy" in calc.cache.list_cached_properties() - prop = calc.get_energy(positions) + prop = calc.get_energy(pos) assert calc._ncalcs == 1 assert isinstance(prop, Tensor) # cache is used for forces (needs `functorch` to be equivalent) assert "forces" in calc.cache.list_cached_properties() - prop = calc.get_forces(positions, grad_mode="functorch") + prop = calc.get_forces(pos, grad_mode="functorch") assert 
calc._ncalcs == 1
     assert isinstance(prop, Tensor)
@@ -166,38 +166,38 @@ def test_vibration(dtype: torch.dtype, use_functorch: bool) -> None:
 
     numbers = torch.tensor([3, 1], device=DEVICE)
     positions = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]], **dd)
-    positions.requires_grad_(True)
+    pos = positions.clone().requires_grad_(True)
 
     options = dict(opts, **{"scf_mode": "full", "mixer": "anderson"})
     calc = AutogradCalculator(numbers, GFN1_XTB, opts=options, **dd)
     assert calc._ncalcs == 0
 
-    prop = calc.get_normal_modes(positions, use_functorch=use_functorch)
+    prop = calc.get_normal_modes(pos, use_functorch=use_functorch)
     assert calc._ncalcs == 1
     assert isinstance(prop, Tensor)
     assert "normal_modes" in calc.cache.list_cached_properties()
 
     # cache is used for freqs
     assert "frequencies" in calc.cache.list_cached_properties()
-    prop = calc.get_frequencies(positions, use_functorch=use_functorch)
+    prop = calc.get_frequencies(pos, use_functorch=use_functorch)
     assert calc._ncalcs == 1
     assert isinstance(prop, Tensor)
 
     # cache is used for full vibration result
     assert "vibration" in calc.cache.list_cached_properties()
-    prop = calc.get_vibration(positions, use_functorch=use_functorch)
+    prop = calc.get_vibration(pos, use_functorch=use_functorch)
    assert calc._ncalcs == 1
     assert isinstance(prop, VibResult)
 
     # cache is used for forces (needs `functorch` to be equivalent)
     assert "forces" in calc.cache.list_cached_properties()
     grad_mode = "autograd" if use_functorch is False else "functorch"
-    prop = calc.get_forces(positions, grad_mode=grad_mode)
+    prop = calc.get_forces(pos, grad_mode=grad_mode)
     assert calc._ncalcs == 1
     assert isinstance(prop, Tensor)
 
     # cache is used for hessian (needs `matrix=False` bo be equivalent)
     assert "hessian" in calc.cache.list_cached_properties()
-    prop = calc.get_hessian(positions, use_functorch=use_functorch, matrix=False)
+    prop = calc.get_hessian(pos, use_functorch=use_functorch, matrix=False)
     assert calc._ncalcs == 1
     assert isinstance(prop, Tensor)
diff --git a/test/test_calculator/test_general.py b/test/test_calculator/test_general.py
index 9131b0ad3..2317a9f5a 100644
--- a/test/test_calculator/test_general.py
+++ b/test/test_calculator/test_general.py
@@ -53,8 +53,6 @@ def run_asserts(c: Calculator, dtype: torch.dtype) -> None:
 
     assert c.integrals.dtype == dtype
 
-    assert c.integrals.driver.dtype == dtype
-
 
 def test_change_type() -> None:
     numbers = torch.tensor([6, 1, 1, 1, 1])
@@ -77,11 +75,12 @@ def test_change_type_after_energy() -> None:
 
     run_asserts(calc_64, dtype)
 
-    # extra asserts on initialized vars
-    assert calc_64.integrals.driver.basis.dtype == dtype
-    assert calc_64.integrals.driver.basis.ngauss.dtype == torch.uint8
-    assert calc_64.integrals.driver.basis.pqn.dtype == torch.uint8
-    assert calc_64.integrals.driver.basis.slater.dtype == dtype
+    # extra asserts on initialized variables
+    bas = calc_64.integrals.mgr.driver.basis
+    assert bas.dtype == dtype
+    assert bas.ngauss.dtype == torch.uint8
+    assert bas.pqn.dtype == torch.uint8
+    assert bas.slater.dtype == dtype
 
     assert calc_64.integrals.hcore is not None
     assert calc_64.integrals.hcore.dtype == dtype
diff --git a/test/test_coulomb/test_es2_atom.py b/test/test_coulomb/test_es2_atom.py
index 1af6f0474..39e2e7cd6 100644
--- a/test/test_coulomb/test_es2_atom.py
+++ b/test/test_coulomb/test_es2_atom.py
@@ -34,7 +34,7 @@
 from dxtb._src.param.utils import get_elem_param
 from dxtb._src.typing import DD, Tensor
 
-from ..conftest import DEVICE
+from ..conftest import DEVICE, NONDET_TOL
 from .samples import samples
 
 sample_list = 
["MB16_43_01", "MB16_43_02", "SiH4_atom"] @@ -121,13 +121,13 @@ def test_grad_positions(name: str) -> None: assert es is not None # variable to be differentiated - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) - def func(positions: Tensor): - cache = es.get_cache(numbers, positions, ihelp) + def func(p: Tensor): + cache = es.get_cache(numbers, p, ihelp) return es.get_atom_energy(qat, cache) - assert dgradcheck(func, positions) + assert dgradcheck(func, pos, nondet_tol=NONDET_TOL) @pytest.mark.grad @@ -157,4 +157,4 @@ def func(gexp: Tensor, hubbard: Tensor): cache = es.get_cache(numbers, positions, ihelp) return es.get_atom_energy(qat, cache) - assert dgradcheck(func, (gexp, hubbard)) + assert dgradcheck(func, (gexp, hubbard), nondet_tol=NONDET_TOL) diff --git a/test/test_coulomb/test_es2_general.py b/test/test_coulomb/test_es2_general.py index 74a9b6cfc..5ac651835 100644 --- a/test/test_coulomb/test_es2_general.py +++ b/test/test_coulomb/test_es2_general.py @@ -73,8 +73,8 @@ def test_grad_fail() -> None: assert (torch.zeros_like(positions) == grad).all() with pytest.raises(RuntimeError): - positions.requires_grad_(False) - es._gradient(energy, positions) + pos = positions.clone().requires_grad_(False) + es._gradient(energy, pos) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.float64]) diff --git a/test/test_coulomb/test_es2_shell.py b/test/test_coulomb/test_es2_shell.py index f96bc6a19..6c7299ca9 100644 --- a/test/test_coulomb/test_es2_shell.py +++ b/test/test_coulomb/test_es2_shell.py @@ -34,7 +34,7 @@ from dxtb._src.param.utils import get_elem_param from dxtb._src.typing import DD, Tensor -from ..conftest import DEVICE +from ..conftest import DEVICE, NONDET_TOL from .samples import samples sample_list = ["MB16_43_07", "MB16_43_08", "SiH4"] @@ -119,17 +119,17 @@ def test_grad_positions(name: str) -> None: ihelp = IndexHelper.from_numbers(numbers, GFN1_XTB) # variable to be differentiated - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) - def func(positions: Tensor): + def func(p: Tensor): es = es2.new_es2(numbers, GFN1_XTB, shell_resolved=False, **dd) if es is None: assert False - cache = es.get_cache(numbers, positions, ihelp) + cache = es.get_cache(numbers, p, ihelp) return es.get_shell_energy(qsh, cache) - assert dgradcheck(func, positions) + assert dgradcheck(func, pos, nondet_tol=NONDET_TOL) @pytest.mark.grad @@ -160,4 +160,4 @@ def func(gexp: Tensor, hubbard: Tensor): cache = es.get_cache(numbers, positions, ihelp) return es.get_shell_energy(qsh, cache) - assert dgradcheck(func, (gexp, hubbard)) + assert dgradcheck(func, (gexp, hubbard), nondet_tol=NONDET_TOL) diff --git a/test/test_coulomb/test_es3.py b/test/test_coulomb/test_es3.py index 36d4a834f..b04cc6ac2 100644 --- a/test/test_coulomb/test_es3.py +++ b/test/test_coulomb/test_es3.py @@ -31,7 +31,7 @@ from dxtb._src.param.utils import get_elem_param from dxtb._src.typing import DD, Tensor -from ..conftest import DEVICE +from ..conftest import DEVICE, NONDET_TOL from .samples import samples sample_list = ["MB16_43_01", "MB16_43_02", "SiH4_atom"] @@ -120,4 +120,4 @@ def func(hubbard_derivs: Tensor): cache = es.get_cache(numbers=numbers, ihelp=ihelp) return es.get_atom_energy(qat, cache) - assert dgradcheck(func, hd) + assert dgradcheck(func, hd, nondet_tol=NONDET_TOL) diff --git a/test/test_coulomb/test_grad_atom.py b/test/test_coulomb/test_grad_atom.py index d0a77edcf..b7ee647c9 100644 --- a/test/test_coulomb/test_grad_atom.py +++ 
b/test/test_coulomb/test_grad_atom.py @@ -70,22 +70,20 @@ def test_single(dtype: torch.dtype, name: str) -> None: assert pytest.approx(num_grad.cpu(), abs=tol) == grad.cpu() # automatic - positions.requires_grad_(True) - mat = es.get_atom_coulomb_matrix(numbers, positions, ihelp) + pos = positions.clone().requires_grad_(True) + mat = es.get_atom_coulomb_matrix(numbers, pos, ihelp) energy = 0.5 * mat * charges.unsqueeze(-1) * charges.unsqueeze(-2) - (agrad,) = torch.autograd.grad(energy.sum(), positions) + (agrad,) = torch.autograd.grad(energy.sum(), pos) assert pytest.approx(ref.cpu(), abs=tol) == agrad.cpu() # analytical (automatic) es.cache_invalidate() - cache = es.get_cache(numbers, positions, ihelp) # recalc with gradients - egrad = es.get_atom_gradient(charges, positions, cache) + cache = es.get_cache(numbers, pos, ihelp) # recalc with gradients + egrad = es.get_atom_gradient(charges, pos, cache) egrad.detach_() assert pytest.approx(ref.cpu(), abs=tol) == egrad.cpu() assert pytest.approx(egrad.cpu(), abs=tol) == agrad.cpu() - positions.detach_() - @pytest.mark.parametrize("dtype", [torch.float, torch.double]) @pytest.mark.parametrize("name1", ["SiH4_atom"]) @@ -132,21 +130,19 @@ def test_batch(dtype: torch.dtype, name1: str, name2: str) -> None: assert pytest.approx(ref.cpu(), abs=tol) == grad.cpu() # automatic - positions.requires_grad_(True) - mat = es.get_atom_coulomb_matrix(numbers, positions, ihelp) + pos = positions.clone().requires_grad_(True) + mat = es.get_atom_coulomb_matrix(numbers, pos, ihelp) energy = 0.5 * mat * charges.unsqueeze(-1) * charges.unsqueeze(-2) - (agrad,) = torch.autograd.grad(energy.sum(), positions) + (agrad,) = torch.autograd.grad(energy.sum(), pos) assert pytest.approx(ref.cpu(), abs=tol) == agrad.cpu() # analytical (automatic) - cache = es.get_cache(numbers, positions, ihelp) # recalc with gradients - egrad = es.get_atom_gradient(charges, positions, cache) + cache = es.get_cache(numbers, pos, ihelp) # recalc with gradients + egrad = es.get_atom_gradient(charges, pos, cache) egrad.detach_() assert pytest.approx(ref.cpu(), abs=tol) == egrad.cpu() assert pytest.approx(egrad.cpu(), abs=tol) == agrad.cpu() - positions.detach_() - def calc_numerical_gradient( numbers: Tensor, positions: Tensor, ihelp: IndexHelper, charges: Tensor diff --git a/test/test_coulomb/test_grad_atom_param.py b/test/test_coulomb/test_grad_atom_param.py index bad4925fa..7ec85b9df 100644 --- a/test/test_coulomb/test_grad_atom_param.py +++ b/test/test_coulomb/test_grad_atom_param.py @@ -31,7 +31,7 @@ from dxtb._src.param import get_elem_param from dxtb._src.typing import DD, Callable, Tensor -from ..conftest import DEVICE +from ..conftest import DEVICE, NONDET_TOL from .samples import samples sample_list = ["LiH", "SiH4", "MB16_43_01"] @@ -81,7 +81,7 @@ def test_grad_param(dtype: torch.dtype, name: str) -> None: gradient from `torch.autograd.gradcheck`. """ func, diffvars = gradcheck_param(dtype, name) - assert dgradcheck(func, diffvars, atol=tol) + assert dgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL) @pytest.mark.grad @@ -93,7 +93,7 @@ def test_gradgrad_param(dtype: torch.dtype, name: str) -> None: gradient from `torch.autograd.gradgradcheck`. 
""" func, diffvars = gradcheck_param(dtype, name) - assert dgradgradcheck(func, diffvars, atol=tol) + assert dgradgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL) def gradcheck_param_batch(dtype: torch.dtype, name1: str, name2: str) -> tuple[ @@ -157,7 +157,7 @@ def test_grad_param_batch(dtype: torch.dtype, name1: str, name2: str) -> None: # same for both values. diffvars[0].requires_grad_(False) - assert dgradcheck(func, diffvars, atol=tol) + assert dgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL) @pytest.mark.grad @@ -177,4 +177,4 @@ def test_gradgrad_param_batch(dtype: torch.dtype, name1: str, name2: str) -> Non # same for both values. diffvars[0].requires_grad_(False) - assert dgradgradcheck(func, diffvars, atol=tol) + assert dgradgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL) diff --git a/test/test_coulomb/test_grad_atom_pos.py b/test/test_coulomb/test_grad_atom_pos.py index 49fc2d9f9..593257cd1 100644 --- a/test/test_coulomb/test_grad_atom_pos.py +++ b/test/test_coulomb/test_grad_atom_pos.py @@ -31,7 +31,7 @@ from dxtb._src.param import get_elem_param from dxtb._src.typing import DD, Callable, Tensor -from ..conftest import DEVICE +from ..conftest import DEVICE, NONDET_TOL from .samples import samples sample_list = ["LiH", "SiH4", "MB16_43_01"] @@ -65,15 +65,15 @@ def gradcheck_pos( gexp = torch.tensor(par.charge.effective.gexp, **dd) # variables to be differentiated - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) es2 = ES2(hubbard, None, gexp=gexp, shell_resolved=False, **dd) es2.cache_disable() - def func(positions: Tensor) -> Tensor: - return es2.get_atom_coulomb_matrix(numbers, positions, ihelp) + def func(p: Tensor) -> Tensor: + return es2.get_atom_coulomb_matrix(numbers, p, ihelp) - return func, positions + return func, pos @pytest.mark.grad @@ -85,7 +85,7 @@ def test_grad_pos(dtype: torch.dtype, name: str) -> None: gradient from `torch.autograd.gradcheck`. """ func, diffvars = gradcheck_pos(dtype, name) - assert dgradcheck(func, diffvars, atol=tol) + assert dgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL) @pytest.mark.grad @@ -97,7 +97,7 @@ def test_gradgrad_pos(dtype: torch.dtype, name: str) -> None: gradient from `torch.autograd.gradgradcheck`. """ func, diffvars = gradcheck_pos(dtype, name) - assert dgradgradcheck(func, diffvars, atol=tol) + assert dgradgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL) def gradcheck_pos_batch( @@ -136,14 +136,14 @@ def gradcheck_pos_batch( gexp = torch.tensor(par.charge.effective.gexp, **dd) # variables to be differentiated - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) es2 = ES2(hubbard, None, gexp=gexp, shell_resolved=False, **dd) - def func(positions: Tensor) -> Tensor: - return es2.get_atom_coulomb_matrix(numbers, positions, ihelp) + def func(p: Tensor) -> Tensor: + return es2.get_atom_coulomb_matrix(numbers, p, ihelp) - return func, positions + return func, pos @pytest.mark.grad @@ -156,7 +156,7 @@ def test_grad_pos_batch(dtype: torch.dtype, name1: str, name2: str) -> None: gradient from `torch.autograd.gradcheck`. """ func, diffvars = gradcheck_pos_batch(dtype, name1, name2) - assert dgradcheck(func, diffvars, atol=tol) + assert dgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL) @pytest.mark.grad @@ -169,4 +169,4 @@ def test_gradgrad_pos_batch(dtype: torch.dtype, name1: str, name2: str) -> None: gradient from `torch.autograd.gradgradcheck`. 
""" func, diffvars = gradcheck_pos_batch(dtype, name1, name2) - assert dgradgradcheck(func, diffvars, atol=tol) + assert dgradgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL) diff --git a/test/test_coulomb/test_grad_shell.py b/test/test_coulomb/test_grad_shell.py index f788c3575..eefbf69bd 100644 --- a/test/test_coulomb/test_grad_shell.py +++ b/test/test_coulomb/test_grad_shell.py @@ -69,21 +69,19 @@ def test_single(dtype: torch.dtype, name: str) -> None: assert pytest.approx(num_grad.cpu(), abs=tol) == grad.cpu() # automatic - positions.requires_grad_(True) - mat = es.get_shell_coulomb_matrix(numbers, positions, ihelp) + pos = positions.clone().requires_grad_(True) + mat = es.get_shell_coulomb_matrix(numbers, pos, ihelp) energy = 0.5 * mat * charges.unsqueeze(-1) * charges.unsqueeze(-2) - (agrad,) = torch.autograd.grad(energy.sum(), positions) + (agrad,) = torch.autograd.grad(energy.sum(), pos) assert pytest.approx(ref.cpu(), abs=tol) == agrad.cpu() # analytical (automatic) - cache = es.get_cache(numbers, positions, ihelp) # recalc with gradients - egrad = es.get_shell_gradient(charges, positions, cache) + cache = es.get_cache(numbers, pos, ihelp) # recalc with gradients + egrad = es.get_shell_gradient(charges, pos, cache) egrad.detach_() assert pytest.approx(ref.cpu(), abs=tol) == egrad.cpu() assert pytest.approx(egrad.cpu(), abs=tol) == agrad.cpu() - positions.detach_() - @pytest.mark.parametrize("dtype", [torch.float, torch.double]) @pytest.mark.parametrize("name1", ["SiH4"]) @@ -131,21 +129,19 @@ def test_batch(dtype: torch.dtype, name1: str, name2: str) -> None: assert pytest.approx(grad.cpu(), abs=tol) == ref.cpu() # automatic - positions.requires_grad_(True) - mat = es.get_shell_coulomb_matrix(numbers, positions, ihelp) + pos = positions.clone().requires_grad_(True) + mat = es.get_shell_coulomb_matrix(numbers, pos, ihelp) energy = 0.5 * mat * charges.unsqueeze(-1) * charges.unsqueeze(-2) - (agrad,) = torch.autograd.grad(energy.sum(), positions) + (agrad,) = torch.autograd.grad(energy.sum(), pos) assert pytest.approx(ref.cpu(), abs=tol) == agrad.cpu() # analytical (automatic) - cache = es.get_cache(numbers, positions, ihelp) # recalc with gradients - egrad = es.get_shell_gradient(charges, positions, cache) + cache = es.get_cache(numbers, pos, ihelp) # recalc with gradients + egrad = es.get_shell_gradient(charges, pos, cache) egrad.detach_() assert pytest.approx(ref.cpu(), abs=tol) == egrad.cpu() assert pytest.approx(egrad.cpu(), abs=tol) == agrad.cpu() - positions.detach_() - def calc_numerical_gradient( numbers: Tensor, positions: Tensor, ihelp: IndexHelper, charges: Tensor diff --git a/test/test_coulomb/test_grad_shell_param.py b/test/test_coulomb/test_grad_shell_param.py index b9d678f09..1247ba186 100644 --- a/test/test_coulomb/test_grad_shell_param.py +++ b/test/test_coulomb/test_grad_shell_param.py @@ -31,7 +31,7 @@ from dxtb._src.param import get_elem_param from dxtb._src.typing import DD, Callable, Tensor -from ..conftest import DEVICE +from ..conftest import DEVICE, NONDET_TOL from .samples import samples sample_list = ["LiH", "SiH4"] # "MB16_43_01" requires a lot of RAM @@ -90,7 +90,7 @@ def test_grad_param(dtype: torch.dtype, name: str) -> None: gradient from `torch.autograd.gradcheck`. 
""" func, diffvars = gradcheck_param(dtype, name) - assert dgradcheck(func, diffvars, atol=tol) + assert dgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL) @pytest.mark.grad @@ -103,7 +103,7 @@ def test_grad_param_large(dtype: torch.dtype, name: str) -> None: gradient from `torch.autograd.gradcheck`. """ func, diffvars = gradcheck_param(dtype, name) - assert dgradcheck(func, diffvars, atol=tol) + assert dgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL) @pytest.mark.grad @@ -115,7 +115,7 @@ def test_gradgrad_param(dtype: torch.dtype, name: str) -> None: gradient from `torch.autograd.gradgradcheck`. """ func, diffvars = gradcheck_param(dtype, name) - assert dgradgradcheck(func, diffvars, atol=tol) + assert dgradgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL) @pytest.mark.grad @@ -128,7 +128,7 @@ def test_gradgrad_param_large(dtype: torch.dtype, name: str) -> None: gradient from `torch.autograd.gradcheck`. """ func, diffvars = gradcheck_param(dtype, name) - assert dgradgradcheck(func, diffvars, atol=tol) + assert dgradgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL) def gradcheck_param_batch(dtype: torch.dtype, name1: str, name2: str) -> tuple[ @@ -194,7 +194,7 @@ def test_grad_param_batch(dtype: torch.dtype, name1: str, name2: str) -> None: gradient from `torch.autograd.gradcheck`. """ func, diffvars = gradcheck_param_batch(dtype, name1, name2) - assert dgradcheck(func, diffvars, atol=tol) + assert dgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL) @pytest.mark.grad @@ -207,4 +207,4 @@ def test_gradgrad_param_batch(dtype: torch.dtype, name1: str, name2: str) -> Non gradient from `torch.autograd.gradgradcheck`. """ func, diffvars = gradcheck_param_batch(dtype, name1, name2) - assert dgradgradcheck(func, diffvars, atol=tol) + assert dgradgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL) diff --git a/test/test_coulomb/test_grad_shell_pos.py b/test/test_coulomb/test_grad_shell_pos.py index 5f4cedaab..ba9fdbabd 100644 --- a/test/test_coulomb/test_grad_shell_pos.py +++ b/test/test_coulomb/test_grad_shell_pos.py @@ -31,7 +31,7 @@ from dxtb._src.param import get_elem_param from dxtb._src.typing import DD, Callable, Tensor -from ..conftest import DEVICE +from ..conftest import DEVICE, NONDET_TOL from .samples import samples sample_list = ["LiH", "SiH4"] @@ -72,15 +72,15 @@ def gradchecker( gexp = torch.tensor(par.charge.effective.gexp, **dd) # variables to be differentiated - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) es2 = ES2(hubbard, lhubbard, gexp=gexp, shell_resolved=True, **dd) es2.cache_disable() - def func(positions: Tensor) -> Tensor: - return es2.get_shell_coulomb_matrix(numbers, positions, ihelp) + def func(p: Tensor) -> Tensor: + return es2.get_shell_coulomb_matrix(numbers, p, ihelp) - return func, positions + return func, pos @pytest.mark.grad @@ -92,7 +92,7 @@ def test_grad(dtype: torch.dtype, name: str) -> None: gradient from `torch.autograd.gradcheck`. """ func, diffvars = gradchecker(dtype, name) - assert dgradcheck(func, diffvars, atol=tol) + assert dgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL) @pytest.mark.grad @@ -105,7 +105,7 @@ def test_grad_large(dtype: torch.dtype, name: str) -> None: gradient from `torch.autograd.gradcheck`. 
""" func, diffvars = gradchecker(dtype, name) - assert dgradcheck(func, diffvars, atol=tol, fast_mode=True) + assert dgradcheck(func, diffvars, atol=tol, fast_mode=True, nondet_tol=NONDET_TOL) @pytest.mark.grad @@ -118,7 +118,7 @@ def test_gradgrad(dtype: torch.dtype, name: str) -> None: gradient from `torch.autograd.gradgradcheck`. """ func, diffvars = gradchecker(dtype, name) - assert dgradgradcheck(func, diffvars, atol=tol) + assert dgradgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL) @pytest.mark.grad @@ -130,7 +130,9 @@ def test_gradgrad_large(dtype: torch.dtype, name: str) -> None: gradient from `torch.autograd.gradgradcheck`. """ func, diffvars = gradchecker(dtype, name) - assert dgradgradcheck(func, diffvars, atol=tol, fast_mode=True) + assert dgradgradcheck( + func, diffvars, atol=tol, fast_mode=True, nondet_tol=NONDET_TOL + ) def gradchecker_batch( @@ -176,14 +178,14 @@ def gradchecker_batch( gexp = torch.tensor(par.charge.effective.gexp, **dd) # variables to be differentiated - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) es2 = ES2(hubbard, lhubbard, gexp=gexp, shell_resolved=True, **dd) - def func(positions: Tensor) -> Tensor: - return es2.get_shell_coulomb_matrix(numbers, positions, ihelp) + def func(p: Tensor) -> Tensor: + return es2.get_shell_coulomb_matrix(numbers, p, ihelp) - return func, positions + return func, pos @pytest.mark.grad @@ -196,7 +198,7 @@ def test_grad_batch(dtype: torch.dtype, name1: str, name2: str) -> None: gradient from `torch.autograd.gradcheck`. """ func, diffvars = gradchecker_batch(dtype, name1, name2) - assert dgradcheck(func, diffvars, atol=tol, nondet_tol=1e-7) + assert dgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL) @pytest.mark.grad @@ -209,4 +211,4 @@ def test_gradgrad_batch(dtype: torch.dtype, name1: str, name2: str) -> None: gradient from `torch.autograd.gradgradcheck`. 
""" func, diffvars = gradchecker_batch(dtype, name1, name2) - assert dgradgradcheck(func, diffvars, atol=tol, nondet_tol=1e-7) + assert dgradgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL) diff --git a/test/test_dispersion/test_d3.py b/test/test_dispersion/test_d3.py index 352529d1a..fc973f505 100644 --- a/test/test_dispersion/test_d3.py +++ b/test/test_dispersion/test_d3.py @@ -114,7 +114,7 @@ def test_grad_pos() -> None: positions = sample["positions"].to(**dd).detach().clone() # variable to be differentiated - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) disp = new_dispersion(numbers, par, **dd) if disp is None: @@ -122,13 +122,13 @@ def test_grad_pos() -> None: cache = disp.get_cache(numbers) - def func(positions: Tensor) -> Tensor: - return disp.get_energy(positions, cache) + def func(p: Tensor) -> Tensor: + return disp.get_energy(p, cache) # pylint: disable=import-outside-toplevel from torch.autograd.gradcheck import gradcheck - assert gradcheck(func, positions) + assert gradcheck(func, pos) @pytest.mark.grad @@ -140,7 +140,6 @@ def test_grad_pos_tblite(dtype: torch.dtype) -> None: sample = samples["PbH4-BiH3"] numbers = sample["numbers"].to(DEVICE) positions = sample["positions"].to(**dd).detach().clone() - positions.requires_grad_(True) ref = sample["grad"].to(**dd) disp = new_dispersion(numbers, par, **dd) @@ -149,13 +148,16 @@ def test_grad_pos_tblite(dtype: torch.dtype) -> None: cache = disp.get_cache(numbers) + # variable to be differentiated + pos = positions.clone().requires_grad_(True) + # automatic gradient - energy = torch.sum(disp.get_energy(positions, cache), dim=-1) + energy = torch.sum(disp.get_energy(pos, cache), dim=-1) energy.backward() - if positions.grad is None: + if pos.grad is None: assert False - grad_backward = positions.grad.clone() + grad_backward = pos.grad.clone() assert pytest.approx(grad_backward.cpu(), abs=1e-10) == ref.cpu() diff --git a/test/test_dispersion/test_d4.py b/test/test_dispersion/test_d4.py index cd0fc389a..f27ec776a 100644 --- a/test/test_dispersion/test_d4.py +++ b/test/test_dispersion/test_d4.py @@ -100,7 +100,7 @@ def test_grad_pos() -> None: charge = positions.new_tensor(0.0) # variable to be differentiated - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) disp = new_dispersion(numbers, par, charge, **dd) if disp is None: @@ -108,13 +108,13 @@ def test_grad_pos() -> None: cache = disp.get_cache(numbers) - def func(positions: Tensor) -> Tensor: - return disp.get_energy(positions, cache) + def func(p: Tensor) -> Tensor: + return disp.get_energy(p, cache) # pylint: disable=import-outside-toplevel from torch.autograd.gradcheck import gradcheck - assert gradcheck(func, positions) + assert gradcheck(func, pos) @pytest.mark.grad diff --git a/test/test_dispersion/test_grad_pos.py b/test/test_dispersion/test_grad_pos.py index 8d353dda1..db6327d9b 100644 --- a/test/test_dispersion/test_grad_pos.py +++ b/test/test_dispersion/test_grad_pos.py @@ -49,17 +49,17 @@ def gradchecker(dtype: torch.dtype, name: str) -> tuple[ positions = sample["positions"].to(**dd) # variable to be differentiated - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) disp = new_dispersion(numbers, par, **dd) assert disp is not None cache = disp.get_cache(numbers) - def func(positions: Tensor) -> Tensor: - return disp.get_energy(positions, cache) + def func(p: Tensor) -> Tensor: + return disp.get_energy(p, cache) - return func, positions + return func, pos @pytest.mark.grad 
@@ -133,17 +133,17 @@ def gradchecker_batch(dtype: torch.dtype, name1: str, name2: str) -> tuple[ ) # variable to be differentiated - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) disp = new_dispersion(numbers, par, **dd) assert disp is not None cache = disp.get_cache(numbers) - def func(positions: Tensor) -> Tensor: - return disp.get_energy(positions, cache) + def func(p: Tensor) -> Tensor: + return disp.get_energy(p, cache) - return func, positions + return func, pos @pytest.mark.grad @@ -208,7 +208,6 @@ def test_autograd(dtype: torch.dtype, name: str) -> None: sample = samples[name] numbers = sample["numbers"].to(DEVICE) positions = sample["positions"].to(**dd) - positions.requires_grad_(True) ref = sample["grad"].to(**dd) disp = new_dispersion(numbers, par, **dd) @@ -216,12 +215,13 @@ def test_autograd(dtype: torch.dtype, name: str) -> None: cache = disp.get_cache(numbers) - energy = disp.get_energy(positions, cache) - grad_autograd = disp.get_gradient(energy, positions) + # variable to be differentiated + pos = positions.clone().requires_grad_(True) - assert pytest.approx(ref.cpu(), abs=tol) == grad_autograd.detach().cpu() + energy = disp.get_energy(pos, cache) + grad_autograd = disp.get_gradient(energy, pos) - positions.detach_() + assert pytest.approx(ref.cpu(), abs=tol) == grad_autograd.detach().cpu() @pytest.mark.grad @@ -251,17 +251,18 @@ def test_autograd_batch(dtype: torch.dtype, name1: str, name2: str) -> None: ] ) - positions.requires_grad_(True) + # variable to be differentiated + pos = positions.clone().requires_grad_(True) disp = new_dispersion(numbers, par, **dd) assert disp is not None cache = disp.get_cache(numbers) - energy = disp.get_energy(positions, cache) - grad_autograd = disp.get_gradient(energy, positions) + energy = disp.get_energy(pos, cache) + grad_autograd = disp.get_gradient(energy, pos) - positions.detach_() + pos.detach_() grad_autograd.detach_() assert pytest.approx(ref.cpu(), abs=tol) == grad_autograd.cpu() diff --git a/test/test_dispersion/test_hess.py b/test/test_dispersion/test_hess.py index e6ca03836..62be3e00b 100644 --- a/test/test_dispersion/test_hess.py +++ b/test/test_dispersion/test_hess.py @@ -54,7 +54,7 @@ def test_single(dtype: torch.dtype, name: str) -> None: numref = _numhess(numbers, positions) # variable to be differentiated - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) disp = new_dispersion(numbers, par, **dd) assert disp is not None @@ -64,10 +64,10 @@ def test_single(dtype: torch.dtype, name: str) -> None: def energy(pos: Tensor) -> Tensor: return disp.get_energy(pos, cache).sum() - hess = jacrev(jacrev(energy))(positions) + hess = jacrev(jacrev(energy))(pos) assert isinstance(hess, Tensor) - positions.detach_() + pos.detach_() hess = hess.reshape_as(ref) assert ref.shape == numref.shape == hess.shape @@ -110,7 +110,7 @@ def skip_test_batch(dtype: torch.dtype, name1: str, name2) -> None: ) # variable to be differentiated - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) disp = new_dispersion(numbers, par, **dd) assert disp is not None @@ -120,12 +120,12 @@ def skip_test_batch(dtype: torch.dtype, name1: str, name2) -> None: def energy(pos: Tensor) -> Tensor: return disp.get_energy(pos, cache).sum() - hess = jacrev(jacrev(energy))(positions) + hess = jacrev(jacrev(energy))(pos) assert isinstance(hess, Tensor) assert pytest.approx(ref.cpu(), abs=tol, rel=tol) == hess.detach().cpu() - positions.detach_() + pos.detach_() def _numhess(numbers: 
Tensor, positions: Tensor) -> Tensor:
@@ -139,11 +139,11 @@ def _numhess(numbers: Tensor, positions: Tensor) -> Tensor:
     hess = torch.zeros(*(*positions.shape, *positions.shape), **dd)
 
     step = 1.0e-4
 
-    def _gradfcn(positions: Tensor) -> Tensor:
-        positions.requires_grad_(True)
-        energy = disp.get_energy(positions, cache)
-        gradient = disp.get_gradient(energy, positions)
-        positions.detach_()
+    def _gradfcn(pos: Tensor) -> Tensor:
+        pos.requires_grad_(True)
+        energy = disp.get_energy(pos, cache)
+        gradient = disp.get_gradient(energy, pos)
+        pos.detach_()
         return gradient.detach()
 
     for i in range(numbers.shape[0]):
diff --git a/test/test_halogen/test_grad_pos.py b/test/test_halogen/test_grad_pos.py
index 23c290ee9..fdbdb5f18 100644
--- a/test/test_halogen/test_grad_pos.py
+++ b/test/test_halogen/test_grad_pos.py
@@ -51,7 +51,7 @@ def gradchecker(dtype: torch.dtype, name: str) -> tuple[
     positions = sample["positions"].to(**dd)
 
     # variable to be differentiated
-    positions.requires_grad_(True)
+    pos = positions.clone().requires_grad_(True)
 
     xb = new_halogen(numbers, par, **dd)
     assert xb is not None
@@ -59,10 +59,10 @@ def gradchecker(dtype: torch.dtype, name: str) -> tuple[
     ihelp = IndexHelper.from_numbers(numbers, par)
     cache = xb.get_cache(numbers, ihelp)
 
-    def func(pos: Tensor) -> Tensor:
-        return xb.get_energy(pos, cache)
+    def func(p: Tensor) -> Tensor:
+        return xb.get_energy(p, cache)
 
-    return func, positions
+    return func, pos
 
 
 @pytest.mark.grad
@@ -110,7 +110,7 @@ def gradchecker_batch(dtype: torch.dtype, name1: str, name2: str) -> tuple[
     )
 
     # variable to be differentiated
-    positions.requires_grad_(True)
+    pos = 
positions.clone().requires_grad_(True) xb = new_halogen(numbers, par, **dd) assert xb is not None @@ -215,10 +215,10 @@ def test_autograd_batch(dtype: torch.dtype, name1: str, name2: str) -> None: ihelp = IndexHelper.from_numbers(numbers, par) cache = xb.get_cache(numbers, ihelp) - energy = xb.get_energy(positions, cache) - grad_autograd = xb.get_gradient(energy, positions) + energy = xb.get_energy(pos, cache) + grad_autograd = xb.get_gradient(energy, pos) - positions.detach_() + pos.detach_() grad_autograd.detach_() assert pytest.approx(ref.cpu(), abs=tol * 10) == grad_autograd.cpu() @@ -237,7 +237,7 @@ def test_backward(dtype: torch.dtype, name: str) -> None: ref = sample["gradient"].to(**dd) # variable to be differentiated - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) xb = new_halogen(numbers, par, **dd) assert xb is not None @@ -245,15 +245,15 @@ def test_backward(dtype: torch.dtype, name: str) -> None: ihelp = IndexHelper.from_numbers(numbers, par) cache = xb.get_cache(numbers, ihelp) - energy = xb.get_energy(positions, cache) + energy = xb.get_energy(pos, cache) energy.sum().backward() - assert positions.grad is not None - grad_backward = positions.grad.clone() + assert pos.grad is not None + grad_backward = pos.grad.clone() # also zero out gradients when using `.backward()` - positions.detach_() - positions.grad.data.zero_() + pos.detach_() + pos.grad.data.zero_() assert pytest.approx(ref.cpu(), abs=tol * 10) == grad_backward.cpu() @@ -287,7 +287,7 @@ def test_backward_batch(dtype: torch.dtype, name1: str, name2: str) -> None: ) # variable to be differentiated - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) xb = new_halogen(numbers, par, **dd) assert xb is not None @@ -295,14 +295,14 @@ def test_backward_batch(dtype: torch.dtype, name1: str, name2: str) -> None: ihelp = IndexHelper.from_numbers(numbers, par) cache = xb.get_cache(numbers, ihelp) - energy = xb.get_energy(positions, cache) + energy = xb.get_energy(pos, cache) energy.sum().backward() - assert positions.grad is not None - grad_backward = positions.grad.clone() + assert pos.grad is not None + grad_backward = pos.grad.clone() # also zero out gradients when using `.backward()` - positions.detach_() - positions.grad.data.zero_() + pos.detach_() + pos.grad.data.zero_() assert pytest.approx(ref.cpu(), abs=tol * 10) == grad_backward.cpu() diff --git a/test/test_halogen/test_hess.py b/test/test_halogen/test_hess.py index 882ad0fe1..72e70d019 100644 --- a/test/test_halogen/test_hess.py +++ b/test/test_halogen/test_hess.py @@ -54,7 +54,7 @@ def test_single(dtype: torch.dtype, name: str) -> None: ) # variable to be differentiated - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) xb = new_halogen(numbers, par, **dd) assert xb is not None @@ -65,10 +65,10 @@ def test_single(dtype: torch.dtype, name: str) -> None: def energy(pos: Tensor) -> Tensor: return xb.get_energy(pos, cache).sum() - hess = jacrev(jacrev(energy))(positions) + hess = jacrev(jacrev(energy))(pos) assert isinstance(hess, Tensor) - positions.detach_() + pos.detach_() hess = hess.detach().reshape_as(ref) assert ref.shape == hess.shape @@ -110,7 +110,7 @@ def skip_test_batch(dtype: torch.dtype, name1: str, name2) -> None: ) # variable to be differentiated - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) xb = new_halogen(numbers, par, **dd) assert xb is not None @@ -121,8 +121,8 @@ def skip_test_batch(dtype: torch.dtype, name1: str, name2) -> 
None: def energy(pos: Tensor) -> Tensor: return xb.get_energy(pos, cache).sum() - hess = jacrev(jacrev(energy))(positions) + hess = jacrev(jacrev(energy))(pos) assert isinstance(hess, Tensor) - positions.detach_() + pos.detach_() assert pytest.approx(ref.cpu(), abs=tol, rel=tol) == hess.detach().cpu() diff --git a/test/test_hamiltonian/skip_test_grad.py b/test/test_hamiltonian/skip_test_grad.py index d692bc742..4fc372d04 100644 --- a/test/test_hamiltonian/skip_test_grad.py +++ b/test/test_hamiltonian/skip_test_grad.py @@ -255,11 +255,11 @@ def hamiltonian_grad_single(dtype: torch.dtype, name: str) -> None: sample = samples[name] numbers = sample["numbers"].to(DEVICE) positions = sample["positions"].to(**dd) - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) chrg = torch.tensor(0.0, **dd) calc = Calculator(numbers, par, opts=opts, **dd) - result = calc.singlepoint(positions, chrg) + result = calc.singlepoint(pos, chrg) # check setup o = result.integrals.overlap @@ -275,14 +275,14 @@ def hamiltonian_grad_single(dtype: torch.dtype, name: str) -> None: # compare different overlap calculations driver = IntDriverPytorch(numbers, par, calc.ihelp, **dd) - driver.setup(positions) + driver.setup(pos) overlap2 = s.build(driver).detach() overlap = o.matrix.detach() tol = sqrt(torch.finfo(dtype).eps) * 5 assert pytest.approx(overlap2, abs=tol, rel=tol) == overlap - cn = cn_d3(numbers, positions) + cn = cn_d3(numbers, pos) wmat = get_density( result.coefficients, result.occupation.sum(-2), @@ -294,7 +294,7 @@ def hamiltonian_grad_single(dtype: torch.dtype, name: str) -> None: # analytical gradient dedcn, dedr = h.integral.get_gradient( - positions, + pos, o.matrix, doverlap, result.density, @@ -310,19 +310,19 @@ def hamiltonian_grad_single(dtype: torch.dtype, name: str) -> None: # energy = result.scf.sum(-1) # autograd = torch.autograd.grad( # energy, - # positions, + # pos, # )[0] # # # numerical gradient - # positions.requires_grad_(False) - # numerical = calc_numerical_gradient(calc, positions, numbers, chrg) + # pos.requires_grad_(False) + # numerical = calc_numerical_gradient(calc, pos, numbers, chrg) assert pytest.approx(ref, abs=atol) == dedr.detach() # NOTE: dedcn is already tested in test_hamiltonian assert pytest.approx(ref_dedcn, abs=atol) == dedcn.detach() - positions.detach_() + pos.detach_() @pytest.mark.grad diff --git a/test/test_hamiltonian/test_base.py b/test/test_hamiltonian/test_base.py new file mode 100644 index 000000000..439bcf671 --- /dev/null +++ b/test/test_hamiltonian/test_base.py @@ -0,0 +1,69 @@ +# This file is part of dxtb. +# +# SPDX-Identifier: Apache-2.0 +# Copyright (C) 2024 Grimme Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +General test for base Hamiltonian. 
+""" + +from __future__ import annotations + +import tempfile as td +from pathlib import Path + +import pytest +import torch + +from dxtb import GFN1_XTB as par +from dxtb import IndexHelper +from dxtb._src.xtb.gfn1 import GFN1Hamiltonian + + +def test_requires_grad() -> None: + numbers = torch.tensor([1]) + ihelp = IndexHelper.from_numbers(numbers, par) + + h = GFN1Hamiltonian(numbers, par, ihelp) + + h._matrix = None + assert h.requires_grad is False + + h._matrix = torch.tensor([1.0], requires_grad=True) + assert h.requires_grad is True + + +def test_write_to_pt() -> None: + numbers = torch.tensor([3, 1]) + ihelp = IndexHelper.from_numbers(numbers, par) + + h = GFN1Hamiltonian(numbers, par, ihelp) + h._matrix = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + + with td.TemporaryDirectory() as tmpdir: + p_write = Path(tmpdir) / "test.pt" + h.to_pt(p_write) + + read_mat = torch.load(p_write) + assert pytest.approx(h._matrix.cpu()) == read_mat.cpu() + + with td.TemporaryDirectory() as tmpdir: + p_write = Path(tmpdir) / f"{h.label.casefold()}" + + # To test the None case, inject the temporary path via the label + h.label = str(p_write) + h.to_pt() + + read_mat = torch.load(f"{p_write}.pt") + assert pytest.approx(h._matrix.cpu()) == read_mat.cpu() diff --git a/test/test_hamiltonian/test_grad_pos.py b/test/test_hamiltonian/test_grad_pos.py index ea0c83279..f882da2bc 100644 --- a/test/test_hamiltonian/test_grad_pos.py +++ b/test/test_hamiltonian/test_grad_pos.py @@ -27,10 +27,8 @@ from dxtb import GFN1_XTB as par from dxtb import IndexHelper -from dxtb._src.constants import labels -from dxtb._src.integral.container import Overlap from dxtb._src.integral.driver.pytorch import IntDriverPytorch as IntDriver -from dxtb._src.ncoord import cn_d3 +from dxtb._src.integral.driver.pytorch import OverlapPytorch as Overlap from dxtb._src.typing import DD, Callable, Tensor from dxtb._src.xtb.gfn1 import GFN1Hamiltonian @@ -54,20 +52,19 @@ def gradchecker( ihelp = IndexHelper.from_numbers(numbers, par) h0 = GFN1Hamiltonian(numbers, par, ihelp, **dd) - overlap = Overlap(driver=labels.INTDRIVER_ANALYTICAL, **dd) + overlap = Overlap(**dd) driver = IntDriver(numbers, par, ihelp, **dd) # variables to be differentiated - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) - def func(pos: Tensor) -> Tensor: - driver.setup(positions) + def func(p: Tensor) -> Tensor: + driver.setup(p) s = overlap.build(driver) - cn = cn_d3(numbers, pos) - return h0.build(pos, s, cn=cn) + return h0.build(p, s) - return func, positions + return func, pos @pytest.mark.grad @@ -143,20 +140,19 @@ def gradchecker_batch( ihelp = IndexHelper.from_numbers(numbers, par) h0 = GFN1Hamiltonian(numbers, par, ihelp, **dd) - overlap = Overlap(driver=labels.INTDRIVER_ANALYTICAL, **dd) + overlap = Overlap(**dd) driver = IntDriver(numbers, par, ihelp, **dd) # variables to be differentiated - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) - def func(pos: Tensor) -> Tensor: - driver.setup(positions, mask=mask) + def func(p: Tensor) -> Tensor: + driver.setup(p, mask=mask) s = overlap.build(driver) - cn = cn_d3(numbers, pos) - return h0.build(pos, s, cn=cn) + return h0.build(p, s) - return func, positions + return func, pos @pytest.mark.grad diff --git a/test/test_hamiltonian/test_h0.py b/test/test_hamiltonian/test_h0.py index abd382ff7..c9cf06794 100644 --- a/test/test_hamiltonian/test_h0.py +++ b/test/test_hamiltonian/test_h0.py @@ -31,7 +31,6 @@ from dxtb import GFN1_XTB, IndexHelper from 
dxtb._src.integral.driver.pytorch import IntDriverPytorch as IntDriver from dxtb._src.integral.driver.pytorch import OverlapPytorch as Overlap -from dxtb._src.ncoord import cn_d3 from dxtb._src.param import Param from dxtb._src.typing import DD, Tensor from dxtb._src.xtb.gfn1 import GFN1Hamiltonian @@ -57,8 +56,7 @@ def run(numbers: Tensor, positions: Tensor, par: Param, ref: Tensor, dd: DD) -> driver.setup(positions) s = overlap.build(driver) - cn = cn_d3(numbers, positions) - h = h0.build(positions, s, cn=cn) + h = h0.build(positions, s) assert pytest.approx(h.cpu(), abs=tol) == h.mT.cpu() assert pytest.approx(h.cpu(), abs=tol) == ref.cpu() @@ -82,7 +80,6 @@ def test_single(dtype: torch.dtype, name: str) -> None: def test_batch(dtype: torch.dtype, name1: str, name2: str) -> None: """Batched version.""" dd: DD = {"dtype": dtype, "device": DEVICE} - tol = sqrt(torch.finfo(dtype).eps) * 10 sample1, sample2 = samples[name1], samples[name2] @@ -198,7 +195,7 @@ def test_no_cn(dtype: torch.dtype) -> None: ihelp = IndexHelper.from_numbers(numbers, GFN1_XTB) driver = IntDriver(numbers, GFN1_XTB, ihelp, **dd) overlap = Overlap(**dd) - h0 = GFN1Hamiltonian(numbers, GFN1_XTB, ihelp, **dd) + h0 = GFN1Hamiltonian(numbers, GFN1_XTB, ihelp, cn=None, **dd) driver.setup(positions) s = overlap.build(driver) diff --git a/test/test_indexhelper/test_extra.py b/test/test_indexhelper/test_extra.py index 3455b04de..741242f47 100644 --- a/test/test_indexhelper/test_extra.py +++ b/test/test_indexhelper/test_extra.py @@ -168,7 +168,8 @@ def test_spread_unique_batch() -> None: x = torch.randn((nbatch, nat_u, 3), device=DEVICE) # pollutes CUDA memory - assert False + if DEVICE is not None: + assert False out = ihelp.spread_uspecies_to_atom(x, dim=-2, extra=True) assert out.shape == torch.Size((nbatch, nat, 3)) diff --git a/test/test_integrals/test_driver/__init__.py b/test/test_integrals/test_driver/__init__.py new file mode 100644 index 000000000..15d042be4 --- /dev/null +++ b/test/test_integrals/test_driver/__init__.py @@ -0,0 +1,16 @@ +# This file is part of dxtb. +# +# SPDX-Identifier: Apache-2.0 +# Copyright (C) 2024 Grimme Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/test/test_integrals/test_driver/test_factory.py b/test/test_integrals/test_driver/test_factory.py new file mode 100644 index 000000000..0460db86a --- /dev/null +++ b/test/test_integrals/test_driver/test_factory.py @@ -0,0 +1,70 @@ +# This file is part of dxtb. +# +# SPDX-Identifier: Apache-2.0 +# Copyright (C) 2024 Grimme Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" +Test factories for integral drivers. +""" + +from __future__ import annotations + +import pytest +import torch + +from dxtb import GFN1_XTB, labels +from dxtb._src.integral.driver import factory +from dxtb._src.integral.driver.libcint import IntDriverLibcint +from dxtb._src.integral.driver.pytorch import ( + IntDriverPytorch, + IntDriverPytorchLegacy, + IntDriverPytorchNoAnalytical, +) + +numbers = torch.tensor([14, 1, 1, 1, 1]) + + +def test_fail() -> None: + with pytest.raises(ValueError): + factory.new_driver(-1, numbers, GFN1_XTB) + + +def test_driver() -> None: + cls = factory.new_driver(labels.INTDRIVER_LIBCINT, numbers, GFN1_XTB) + assert isinstance(cls, IntDriverLibcint) + + cls = factory.new_driver(labels.INTDRIVER_ANALYTICAL, numbers, GFN1_XTB) + assert isinstance(cls, IntDriverPytorch) + + cls = factory.new_driver(labels.INTDRIVER_AUTOGRAD, numbers, GFN1_XTB) + assert isinstance(cls, IntDriverPytorchNoAnalytical) + + cls = factory.new_driver(labels.INTDRIVER_LEGACY, numbers, GFN1_XTB) + assert isinstance(cls, IntDriverPytorchLegacy) + + +def test_libcint() -> None: + cls = factory.new_driver_libcint(numbers, GFN1_XTB) + assert isinstance(cls, IntDriverLibcint) + + +def test_pytorch() -> None: + cls = factory.new_driver_pytorch(numbers, GFN1_XTB) + assert isinstance(cls, IntDriverPytorch) + + cls = factory.new_driver_pytorch_no_analytical(numbers, GFN1_XTB) + assert isinstance(cls, IntDriverPytorchNoAnalytical) + + cls = factory.new_driver_legacy(numbers, GFN1_XTB) + assert isinstance(cls, IntDriverPytorchLegacy) diff --git a/test/test_integrals/test_driver.py b/test/test_integrals/test_driver/test_manager.py similarity index 54% rename from test/test_integrals/test_driver.py rename to test/test_integrals/test_driver/test_manager.py index 428b1bbe8..a448b51b6 100644 --- a/test/test_integrals/test_driver.py +++ b/test/test_integrals/test_driver/test_manager.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Test overlap build from integral container. +Test the integral driver manager. 
""" from __future__ import annotations @@ -25,98 +25,95 @@ from dxtb import GFN1_XTB as par from dxtb import IndexHelper -from dxtb import integrals as ints from dxtb._src.constants.labels import INTDRIVER_ANALYTICAL, INTDRIVER_LIBCINT from dxtb._src.integral.driver.libcint import IntDriverLibcint +from dxtb._src.integral.driver.manager import DriverManager from dxtb._src.integral.driver.pytorch import IntDriverPytorch from dxtb._src.typing import DD -from ..conftest import DEVICE +from ...conftest import DEVICE + + +def test_fail() -> None: + mgr = DriverManager(-99) + + with pytest.raises(RuntimeError): + _ = mgr.driver + + with pytest.raises(ValueError): + numbers = torch.tensor([1, 2], device=DEVICE) + mgr.create_driver(numbers, par, IndexHelper.from_numbers(numbers, par)) @pytest.mark.parametrize("dtype", [torch.float, torch.double]) @pytest.mark.parametrize("force_cpu_for_libcint", [True, False]) def test_single(dtype: torch.dtype, force_cpu_for_libcint: bool): - """Overlap matrix for monoatomic molecule should be unity.""" dd: DD = {"dtype": dtype, "device": DEVICE} numbers = torch.tensor([3, 1], device=DEVICE) positions = torch.zeros((2, 3), **dd) ihelp = IndexHelper.from_numbers(numbers, par) - ipy = ints.Integrals( - numbers, - par, - ihelp, - driver=INTDRIVER_ANALYTICAL, - force_cpu_for_libcint=force_cpu_for_libcint, - **dd, + + mgr_py = DriverManager( + INTDRIVER_ANALYTICAL, force_cpu_for_libcint=force_cpu_for_libcint, **dd ) - ilc = ints.Integrals( - numbers, - par, - ihelp, - driver=INTDRIVER_LIBCINT, - force_cpu_for_libcint=force_cpu_for_libcint, - **dd, + mgr_py.create_driver(numbers, par, ihelp) + + mgr_lc = DriverManager( + INTDRIVER_LIBCINT, force_cpu_for_libcint=force_cpu_for_libcint, **dd ) + mgr_lc.create_driver(numbers, par, ihelp) if force_cpu_for_libcint is True: positions = positions.cpu() - ipy.setup_driver(positions) - assert isinstance(ipy.driver, IntDriverPytorch) - ilc.setup_driver(positions) - assert isinstance(ilc.driver, IntDriverLibcint) + mgr_py.setup_driver(positions) + assert isinstance(mgr_py.driver, IntDriverPytorch) + mgr_lc.setup_driver(positions) + assert isinstance(mgr_lc.driver, IntDriverLibcint) - assert ipy.driver.is_latest(positions) is True - assert ilc.driver.is_latest(positions) is True + assert mgr_py.driver.is_latest(positions) is True + assert mgr_lc.driver.is_latest(positions) is True # upon changing the positions, the driver should become outdated positions[0, 0] += 1e-4 - assert ipy.driver.is_latest(positions) is False - assert ilc.driver.is_latest(positions) is False + assert mgr_py.driver.is_latest(positions) is False + assert mgr_lc.driver.is_latest(positions) is False @pytest.mark.parametrize("dtype", [torch.float, torch.double]) @pytest.mark.parametrize("force_cpu_for_libcint", [True, False]) def test_batch(dtype: torch.dtype, force_cpu_for_libcint: bool) -> None: - """Overlap matrix for monoatomic molecule should be unity.""" dd: DD = {"dtype": dtype, "device": DEVICE} numbers = torch.tensor([[3, 1], [1, 0]], device=DEVICE) positions = torch.zeros((2, 2, 3), **dd) ihelp = IndexHelper.from_numbers(numbers, par) - ipy = ints.Integrals( - numbers, - par, - ihelp, - driver=INTDRIVER_ANALYTICAL, - force_cpu_for_libcint=force_cpu_for_libcint, - **dd, + + mgr_py = DriverManager( + INTDRIVER_ANALYTICAL, force_cpu_for_libcint=force_cpu_for_libcint, **dd ) - ilc = ints.Integrals( - numbers, - par, - ihelp, - driver=INTDRIVER_LIBCINT, - force_cpu_for_libcint=force_cpu_for_libcint, - **dd, + mgr_py.create_driver(numbers, par, ihelp) + + 
mgr_lc = DriverManager( + INTDRIVER_LIBCINT, force_cpu_for_libcint=force_cpu_for_libcint, **dd ) + mgr_lc.create_driver(numbers, par, ihelp) if force_cpu_for_libcint is True: positions = positions.cpu() - ipy.setup_driver(positions) - assert isinstance(ipy.driver, IntDriverPytorch) - ilc.setup_driver(positions) - assert isinstance(ilc.driver, IntDriverLibcint) + mgr_py.setup_driver(positions) + assert isinstance(mgr_py.driver, IntDriverPytorch) + mgr_lc.setup_driver(positions) + assert isinstance(mgr_lc.driver, IntDriverLibcint) - assert ipy.driver.is_latest(positions) is True - assert ilc.driver.is_latest(positions) is True + assert mgr_py.driver.is_latest(positions) is True + assert mgr_lc.driver.is_latest(positions) is True # upon changing the positions, the driver should become outdated positions[0, 0] += 1e-4 - assert ipy.driver.is_latest(positions) is False - assert ilc.driver.is_latest(positions) is False + assert mgr_py.driver.is_latest(positions) is False + assert mgr_lc.driver.is_latest(positions) is False diff --git a/test/test_integrals/test_driver/test_pytorch.py b/test/test_integrals/test_driver/test_pytorch.py new file mode 100644 index 000000000..5cdc15815 --- /dev/null +++ b/test/test_integrals/test_driver/test_pytorch.py @@ -0,0 +1,168 @@ +# This file is part of dxtb. +# +# SPDX-Identifier: Apache-2.0 +# Copyright (C) 2024 Grimme Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Test the PyTorch integral driver. 
+""" + +from __future__ import annotations + +import pytest +import torch +from tad_mctc.batch import pack + +from dxtb import GFN1_XTB, IndexHelper +from dxtb._src.integral.driver.pytorch import ( + DipolePytorch, + OverlapPytorch, + QuadrupolePytorch, +) +from dxtb._src.integral.driver.pytorch.driver import BaseIntDriverPytorch +from dxtb._src.typing import DD + +from ...conftest import DEVICE + + +def test_overlap_fail() -> None: + with pytest.raises(ValueError): + _ = OverlapPytorch("wrong") # type: ignore + + +def test_dipole_fail() -> None: + with pytest.raises(NotImplementedError): + _ = DipolePytorch() + + with pytest.raises(ValueError): + _ = DipolePytorch("wrong") # type: ignore + + +def test_quadrupole_fail() -> None: + with pytest.raises(NotImplementedError): + _ = QuadrupolePytorch() + + with pytest.raises(ValueError): + _ = QuadrupolePytorch("wrong") # type: ignore + + +############################################################################## + + +@pytest.mark.parametrize("dtype", [torch.float, torch.double]) +def test_single(dtype: torch.dtype): + dd: DD = {"dtype": dtype, "device": DEVICE} + + numbers = torch.tensor([3, 1], device=DEVICE) + positions = torch.zeros((2, 3), **dd) + ihelp = IndexHelper.from_numbers(numbers, GFN1_XTB) + + drv = BaseIntDriverPytorch(numbers, GFN1_XTB, ihelp, **dd) + drv.setup(positions) + + assert drv._basis is not None + assert drv._positions is not None + + +@pytest.mark.parametrize("dtype", [torch.float, torch.double]) +def test_batch_mode_fail(dtype: torch.dtype) -> None: + dd: DD = {"dtype": dtype, "device": DEVICE} + + numbers = torch.tensor([[3, 1], [1, 0]], device=DEVICE) + positions = torch.zeros((2, 2, 3), **dd) + ihelp = IndexHelper.from_numbers(numbers, GFN1_XTB) + + # set to invalid value + ihelp.batch_mode = -99 + + drv = BaseIntDriverPytorch(numbers, GFN1_XTB, ihelp, **dd) + + with pytest.raises(ValueError): + drv.setup(positions) + + +@pytest.mark.parametrize("dtype", [torch.float, torch.double]) +def test_batch_mode1(dtype: torch.dtype) -> None: + dd: DD = {"dtype": dtype, "device": DEVICE} + + numbers = torch.tensor([[3, 1], [1, 0]], device=DEVICE) + positions = pack( + [ + torch.tensor([[0.0, 0.0, +1.0], [0.0, 0.0, -1.0]], **dd), + torch.tensor([[0.0, 0.0, 2.0]], **dd), + ], + return_mask=False, + ) + ihelp = IndexHelper.from_numbers(numbers, GFN1_XTB, batch_mode=1) + + drv = BaseIntDriverPytorch(numbers, GFN1_XTB, ihelp, **dd) + drv.setup(positions) + + assert drv._basis_batch is not None + assert len(drv._basis_batch) == 2 + + assert drv._positions_batch is not None + assert len(drv._positions_batch) == 2 + + assert drv._positions_batch[0].shape == (2, 3) + assert (drv._positions_batch[0] == positions[0, :, :]).all() + assert drv._positions_batch[1].shape == (1, 3) + assert (drv._positions_batch[1] == positions[1, 0, :]).all() + + +@pytest.mark.parametrize("dtype", [torch.float, torch.double]) +def test_batch_mode1_mask(dtype: torch.dtype) -> None: + dd: DD = {"dtype": dtype, "device": DEVICE} + + numbers = torch.tensor([[3, 1], [1, 0]], device=DEVICE) + positions, mask = pack( + [ + torch.tensor([[0.0, 0.0, +1.0], [0.0, 0.0, -1.0]], **dd), + torch.tensor([[0.0, 0.0, 2.0]], **dd), + ], + return_mask=True, + ) + ihelp = IndexHelper.from_numbers(numbers, GFN1_XTB, batch_mode=1) + + drv = BaseIntDriverPytorch(numbers, GFN1_XTB, ihelp, **dd) + drv.setup(positions, mask=mask) + + assert drv._basis_batch is not None + assert len(drv._basis_batch) == 2 + + assert drv._positions_batch is not None + assert 
len(drv._positions_batch) == 2 + + assert drv._positions_batch[0].shape == (2, 3) + assert (drv._positions_batch[0] == positions[0, :, :]).all() + assert drv._positions_batch[1].shape == (1, 3) + assert (drv._positions_batch[1] == positions[1, 0, :]).all() + + +@pytest.mark.parametrize("dtype", [torch.float, torch.double]) +def test_batch_mode2(dtype: torch.dtype) -> None: + dd: DD = {"dtype": dtype, "device": DEVICE} + + numbers = torch.tensor([[3, 1], [1, 0]], device=DEVICE) + positions = torch.zeros((2, 2, 3), **dd) + ihelp = IndexHelper.from_numbers(numbers, GFN1_XTB, batch_mode=2) + + drv = BaseIntDriverPytorch(numbers, GFN1_XTB, ihelp, **dd) + drv.setup(positions) + + assert drv._basis_batch is not None + assert len(drv._basis_batch) == 2 + + assert drv._positions_batch is not None + assert len(drv._positions_batch) == 2 diff --git a/test/test_integrals/test_factory.py b/test/test_integrals/test_factory.py new file mode 100644 index 000000000..17e945a15 --- /dev/null +++ b/test/test_integrals/test_factory.py @@ -0,0 +1,191 @@ +# This file is part of dxtb. +# +# SPDX-Identifier: Apache-2.0 +# Copyright (C) 2024 Grimme Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Test factories for integral classes. +""" + +from __future__ import annotations + +import pytest +import torch + +from dxtb import GFN1_XTB, GFN2_XTB, IndexHelper, labels +from dxtb._src.integral import factory +from dxtb._src.xtb.gfn1 import GFN1Hamiltonian +from dxtb._src.xtb.gfn2 import GFN2Hamiltonian +from dxtb.integrals import factories, types + +numbers = torch.tensor([14, 1, 1, 1, 1]) +positions = torch.tensor( + [ + [+0.00000000000000, +0.00000000000000, +0.00000000000000], + [+1.61768389755830, +1.61768389755830, -1.61768389755830], + [-1.61768389755830, -1.61768389755830, -1.61768389755830], + [+1.61768389755830, -1.61768389755830, +1.61768389755830], + [-1.61768389755830, +1.61768389755830, +1.61768389755830], + ] +) + + +def test_fail() -> None: + ihelp = IndexHelper.from_numbers(numbers, GFN1_XTB) + + with pytest.raises(ValueError): + par1 = GFN1_XTB.model_copy(deep=True) + par1.meta = None + factories.new_hcore(numbers, par1, ihelp) + + with pytest.raises(ValueError): + par1 = GFN1_XTB.model_copy(deep=True) + assert par1.meta is not None + + par1.meta.name = None + factories.new_hcore(numbers, par1, ihelp) + + with pytest.raises(ValueError): + par1 = GFN1_XTB.model_copy(deep=True) + assert par1.meta is not None + + par1.meta.name = "fail" + factories.new_hcore(numbers, par1, ihelp) + + +def test_hcore() -> None: + ihelp = IndexHelper.from_numbers(numbers, GFN1_XTB) + h0_gfn1 = factory.new_hcore(numbers, GFN1_XTB, ihelp) + assert isinstance(h0_gfn1, GFN1Hamiltonian) + + ihelp = IndexHelper.from_numbers(numbers, GFN2_XTB) + h0_gfn2 = factory.new_hcore(numbers, GFN2_XTB, ihelp) + assert isinstance(h0_gfn2, GFN2Hamiltonian) + + +def test_hcore_gfn1() -> None: + ihelp = IndexHelper.from_numbers(numbers, GFN1_XTB) + + h0 = factory.new_hcore_gfn1(numbers, ihelp) + assert isinstance(h0, GFN1Hamiltonian) + + 
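    # The call below repeats the one above with the parametrization passed
    # explicitly; both are expected to work because the GFN-specific
    # factories default to their own parametrization. The generic
    # factory.new_hcore instead dispatches on par.meta.name, which is why
    # test_fail above demands a ValueError for a missing or unknown name.
    # A rough sketch of that dispatch (the internals are an assumption
    # inferred from these tests, not the actual implementation):
    #
    #   def new_hcore(numbers, par, ihelp, **kwargs):
    #       if par.meta is None or par.meta.name is None:
    #           raise ValueError("Parametrization has no meta name.")
    #       name = par.meta.name.casefold()
    #       if name == "gfn1-xtb":
    #           return new_hcore_gfn1(numbers, ihelp, par, **kwargs)
    #       if name == "gfn2-xtb":
    #           return new_hcore_gfn2(numbers, ihelp, par, **kwargs)
    #       raise ValueError(f"Unsupported parametrization: {name}")
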
h0 = factory.new_hcore_gfn1(numbers, ihelp, GFN1_XTB) + assert isinstance(h0, GFN1Hamiltonian) + + +def test_hcore_gfn2() -> None: + ihelp = IndexHelper.from_numbers(numbers, GFN2_XTB) + + h0 = factory.new_hcore_gfn2(numbers, ihelp) + assert isinstance(h0, GFN2Hamiltonian) + + h0 = factory.new_hcore_gfn2(numbers, ihelp, GFN2_XTB) + assert isinstance(h0, GFN2Hamiltonian) + + +################################################################################ + + +def test_overlap_fail() -> None: + with pytest.raises(ValueError): + factory.new_overlap(-1) + + +def test_overlap() -> None: + cls = factory.new_overlap(labels.INTDRIVER_LIBCINT) + assert isinstance(cls, types.OverlapIntegral) + + cls = factory.new_overlap(labels.INTDRIVER_ANALYTICAL) + assert isinstance(cls, types.OverlapIntegral) + + +def test_overlap_libcint() -> None: + cls = factory.new_overlap_libcint() + assert isinstance(cls, types.OverlapIntegral) + assert cls.device == torch.device("cpu") + + cls = factory.new_overlap_libcint(force_cpu_for_libcint=True) + assert isinstance(cls, types.OverlapIntegral) + assert cls.device == torch.device("cpu") + + +def test_overlap_pytorch() -> None: + cls = factory.new_overlap_pytorch() + assert isinstance(cls, types.OverlapIntegral) + + +################################################################################ + + +def test_dipint_fail() -> None: + with pytest.raises(ValueError): + factory.new_dipint(-1) + + +def test_dipint() -> None: + cls = factory.new_dipint(labels.INTDRIVER_LIBCINT) + assert isinstance(cls, types.DipoleIntegral) + + with pytest.raises(NotImplementedError): + cls = factory.new_dipint(labels.INTDRIVER_ANALYTICAL) + assert isinstance(cls, types.DipoleIntegral) + + +def test_dipint_libcint() -> None: + cls = factory.new_dipint_libcint() + assert isinstance(cls, types.DipoleIntegral) + assert cls.device == torch.device("cpu") + + cls = factory.new_dipint_libcint(force_cpu_for_libcint=True) + assert isinstance(cls, types.DipoleIntegral) + assert cls.device == torch.device("cpu") + + +def test_dipint_pytorch() -> None: + with pytest.raises(NotImplementedError): + cls = factory.new_dipint_pytorch() + assert isinstance(cls, types.DipoleIntegral) + + +################################################################################ + + +def test_quadint_fail() -> None: + with pytest.raises(ValueError): + factory.new_quadint(-1) + + +def test_quadint() -> None: + cls = factory.new_quadint(labels.INTDRIVER_LIBCINT) + assert isinstance(cls, types.QuadrupoleIntegral) + + with pytest.raises(NotImplementedError): + cls = factory.new_quadint(labels.INTDRIVER_ANALYTICAL) + assert isinstance(cls, types.QuadrupoleIntegral) + + +def test_quadint_libcint() -> None: + cls = factory.new_quadint_libcint() + assert isinstance(cls, types.QuadrupoleIntegral) + assert cls.device == torch.device("cpu") + + cls = factory.new_quadint_libcint(force_cpu_for_libcint=True) + assert isinstance(cls, types.QuadrupoleIntegral) + assert cls.device == torch.device("cpu") + + +def test_quadint_pytorch() -> None: + with pytest.raises(NotImplementedError): + cls = factory.new_quadint_pytorch() + assert isinstance(cls, types.QuadrupoleIntegral) diff --git a/test/test_integrals/test_general.py b/test/test_integrals/test_general.py index 4f0720706..1a72ffd4e 100644 --- a/test/test_integrals/test_general.py +++ b/test/test_integrals/test_general.py @@ -27,7 +27,9 @@ from dxtb import IndexHelper from dxtb import integrals as ints from dxtb._src.constants.labels import INTDRIVER_ANALYTICAL, INTDRIVER_LIBCINT 
+from dxtb._src.integral.driver import libcint, pytorch from dxtb._src.typing import DD +from dxtb._src.xtb.gfn1 import GFN1Hamiltonian from ..conftest import DEVICE @@ -37,8 +39,8 @@ def test_empty(dtype: torch.dtype): dd: DD = {"dtype": dtype, "device": DEVICE} numbers = torch.tensor([1, 3], device=DEVICE) - ihelp = IndexHelper.from_numbers(numbers, par) - i = ints.Integrals(numbers, par, ihelp, **dd) + mgr = ints.DriverManager(INTDRIVER_LIBCINT, **dd) + i = ints.Integrals(mgr, **dd) assert i._hcore is None assert i._overlap is None @@ -57,17 +59,20 @@ def test_fail_family(dtype: torch.dtype): numbers = torch.tensor([1, 3], device=DEVICE) ihelp = IndexHelper.from_numbers(numbers, par) - i = ints.Integrals(numbers, par, ihelp, driver=INTDRIVER_ANALYTICAL, **dd) + mgr = ints.DriverManager(INTDRIVER_ANALYTICAL, **dd) + mgr.create_driver(numbers, par, ihelp) + + i = ints.Integrals(mgr, **dd) # make sure the checks are turned on assert i.run_checks is True with pytest.raises(RuntimeError): - i.overlap = ints.types.Overlap(INTDRIVER_LIBCINT, **dd) + i.overlap = libcint.OverlapLibcint(**dd) with pytest.raises(RuntimeError): - i.dipole = ints.types.Dipole(INTDRIVER_LIBCINT, **dd) + i.dipole = libcint.DipoleLibcint(**dd) with pytest.raises(RuntimeError): - i.quadrupole = ints.types.Quadrupole(INTDRIVER_LIBCINT, **dd) + i.quadrupole = libcint.QuadrupoleLibcint(**dd) @pytest.mark.parametrize("dtype", [torch.float, torch.double]) @@ -76,20 +81,23 @@ def test_fail_pytorch_multipole(dtype: torch.dtype): numbers = torch.tensor([1, 3], device=DEVICE) ihelp = IndexHelper.from_numbers(numbers, par) - i = ints.Integrals(numbers, par, ihelp, driver=INTDRIVER_ANALYTICAL, **dd) + mgr = ints.DriverManager(INTDRIVER_LIBCINT, **dd) + mgr.create_driver(numbers, par, ihelp) + + i = ints.Integrals(mgr, **dd) # make sure the checks are turned on assert i.run_checks is True # incompatible driver with pytest.raises(RuntimeError): - i.overlap = ints.types.Overlap(INTDRIVER_LIBCINT, **dd) + i.overlap = pytorch.OverlapPytorch(**dd) # multipole moments not implemented with PyTorch with pytest.raises(NotImplementedError): - i.dipole = ints.types.Dipole(INTDRIVER_ANALYTICAL, **dd) + i.dipole = pytorch.DipolePytorch(**dd) with pytest.raises(NotImplementedError): - i.quadrupole = ints.types.Quadrupole(INTDRIVER_ANALYTICAL, **dd) + i.quadrupole = pytorch.QuadrupolePytorch(**dd) @pytest.mark.parametrize("dtype", [torch.float, torch.double]) @@ -98,9 +106,12 @@ def test_hcore(dtype: torch.dtype): numbers = torch.tensor([1, 3], device=DEVICE) ihelp = IndexHelper.from_numbers(numbers, par) - i = ints.Integrals(numbers, par, ihelp, **dd) - i.hcore = ints.types.HCore(numbers, par, ihelp, **dd) + mgr = ints.DriverManager(INTDRIVER_ANALYTICAL, **dd) + mgr.create_driver(numbers, par, ihelp) + + i = ints.Integrals(mgr, **dd) + i.hcore = GFN1Hamiltonian(numbers, par, ihelp, **dd) h = i.hcore assert h is not None - assert h.integral.matrix is None + assert h.matrix is None diff --git a/test/test_integrals/test_libcint.py b/test/test_integrals/test_libcint.py index 6f0c24ad6..83e630117 100644 --- a/test/test_integrals/test_libcint.py +++ b/test/test_integrals/test_libcint.py @@ -28,42 +28,40 @@ from dxtb import IndexHelper from dxtb import integrals as ints from dxtb import labels -from dxtb._src.exlibs import libcint -from dxtb._src.integral.driver.libcint import IntDriverLibcint -from dxtb._src.typing import DD +from dxtb._src.exlibs.libcint import LibcintWrapper +from dxtb._src.integral.driver import libcint +from 
dxtb._src.integral.driver.manager import DriverManager +from dxtb._src.integral.factory import ( + new_dipint_libcint, + new_overlap_libcint, + new_quadint_libcint, +) +from dxtb._src.typing import DD, Tensor from ..conftest import DEVICE from .samples import samples -@pytest.mark.parametrize("name", ["H2"]) -@pytest.mark.parametrize("dtype", [torch.float, torch.double]) -@pytest.mark.parametrize("force_cpu_for_libcint", [True, False]) -def test_single(dtype: torch.dtype, name: str, force_cpu_for_libcint: bool): - """Overlap matrix for monoatomic molecule should be unity.""" - dd: DD = {"dtype": dtype, "device": DEVICE} - - sample = samples[name] - numbers = sample["numbers"].to(DEVICE) - positions = sample["positions"].to(**dd) - +def run(numbers: Tensor, positions: Tensor, cpu: bool, dd: DD) -> None: ihelp = IndexHelper.from_numbers(numbers, par) - i = ints.Integrals( - numbers, - par, - ihelp, - force_cpu_for_libcint=force_cpu_for_libcint, - intlevel=labels.INTLEVEL_QUADRUPOLE, - **dd, - ) + mgr = DriverManager(labels.INTDRIVER_LIBCINT, force_cpu_for_libcint=cpu, **dd) + mgr.create_driver(numbers, par, ihelp) + + i = ints.Integrals(mgr, intlevel=labels.INTLEVEL_QUADRUPOLE, **dd) + i.build_overlap(positions, force_cpu_for_libcint=cpu) - i.setup_driver(positions) - assert isinstance(i.driver, IntDriverLibcint) - assert isinstance(i.driver.drv, libcint.LibcintWrapper) + if numbers.ndim == 1: + assert isinstance(mgr.driver, libcint.IntDriverLibcint) + assert isinstance(mgr.driver.drv, LibcintWrapper) + else: + assert isinstance(mgr.driver, libcint.IntDriverLibcint) + assert isinstance(mgr.driver.drv, list) + assert isinstance(mgr.driver.drv[0], LibcintWrapper) + assert isinstance(mgr.driver.drv[1], LibcintWrapper) ################################################ - i.overlap = ints.types.Overlap(**dd) + i.overlap = new_overlap_libcint(**dd, force_cpu_for_libcint=cpu) i.build_overlap(positions) o = i.overlap @@ -72,7 +70,7 @@ def test_single(dtype: torch.dtype, name: str, force_cpu_for_libcint: bool): ################################################ - i.dipole = ints.types.Dipole(**dd) + i.dipole = new_dipint_libcint(**dd, force_cpu_for_libcint=cpu) i.build_dipole(positions) d = i.dipole @@ -81,7 +79,7 @@ def test_single(dtype: torch.dtype, name: str, force_cpu_for_libcint: bool): ################################################ - i.quadrupole = ints.types.Quadrupole(**dd) + i.quadrupole = new_quadint_libcint(**dd, force_cpu_for_libcint=cpu) i.build_quadrupole(positions) q = i.quadrupole @@ -89,6 +87,19 @@ def test_single(dtype: torch.dtype, name: str, force_cpu_for_libcint: bool): assert q.matrix is not None +@pytest.mark.parametrize("name", ["H2"]) +@pytest.mark.parametrize("dtype", [torch.float, torch.double]) +@pytest.mark.parametrize("force_cpu_for_libcint", [True, False]) +def test_single(dtype: torch.dtype, name: str, force_cpu_for_libcint: bool): + dd: DD = {"dtype": dtype, "device": DEVICE} + + sample = samples[name] + numbers = sample["numbers"].to(DEVICE) + positions = sample["positions"].to(**dd) + + run(numbers, positions, force_cpu_for_libcint, dd) + + @pytest.mark.parametrize("name1", ["H2"]) @pytest.mark.parametrize("name2", ["LiH"]) @pytest.mark.parametrize("dtype", [torch.float, torch.double]) @@ -96,7 +107,6 @@ def test_single(dtype: torch.dtype, name: str, force_cpu_for_libcint: bool): def test_batch( dtype: torch.dtype, name1: str, name2: str, force_cpu_for_libcint: bool ) -> None: - """Overlap matrix for monoatomic molecule should be unity.""" dd: DD = {"dtype": dtype, 
"device": DEVICE} sample1, sample2 = samples[name1], samples[name2] @@ -114,45 +124,4 @@ def test_batch( ) ) - ihelp = IndexHelper.from_numbers(numbers, par) - i = ints.Integrals( - numbers, - par, - ihelp, - intlevel=labels.INTLEVEL_QUADRUPOLE, - force_cpu_for_libcint=force_cpu_for_libcint, - **dd, - ) - - i.setup_driver(positions) - assert isinstance(i.driver, IntDriverLibcint) - assert isinstance(i.driver.drv, list) - assert isinstance(i.driver.drv[0], libcint.LibcintWrapper) - assert isinstance(i.driver.drv[1], libcint.LibcintWrapper) - - ################################################ - - i.overlap = ints.types.Overlap(**dd) - i.build_overlap(positions) - - o = i.overlap - assert o is not None - assert o.matrix is not None - - ################################################ - - i.dipole = ints.types.Dipole(**dd) - i.build_dipole(positions) - - d = i.dipole - assert d is not None - assert d.matrix is not None - - ################################################ - - i.quadrupole = ints.types.Quadrupole(**dd) - i.build_quadrupole(positions) - - q = i.quadrupole - assert q is not None - assert q.matrix is not None + run(numbers, positions, force_cpu_for_libcint, dd) diff --git a/test/test_integrals/test_pytorch.py b/test/test_integrals/test_pytorch.py index 3cdeac8df..b304cc085 100644 --- a/test/test_integrals/test_pytorch.py +++ b/test/test_integrals/test_pytorch.py @@ -28,7 +28,8 @@ from dxtb import IndexHelper from dxtb import integrals as ints from dxtb._src.constants.labels import INTDRIVER_ANALYTICAL -from dxtb._src.integral.driver.pytorch import IntDriverPytorch +from dxtb._src.integral.driver.manager import DriverManager +from dxtb._src.integral.driver.pytorch import OverlapPytorch from dxtb._src.typing import DD, Tensor from ..conftest import DEVICE @@ -37,12 +38,10 @@ def run(numbers: Tensor, positions: Tensor, dd: DD) -> None: ihelp = IndexHelper.from_numbers(numbers, par) - i = ints.Integrals(numbers, par, ihelp, driver=INTDRIVER_ANALYTICAL, **dd) + mgr = DriverManager(INTDRIVER_ANALYTICAL, **dd) + mgr.create_driver(numbers, par, ihelp) - i.setup_driver(positions) - assert isinstance(i.driver, IntDriverPytorch) - - i.overlap = ints.types.Overlap(driver=INTDRIVER_ANALYTICAL, **dd) + i = ints.Integrals(mgr, _overlap=OverlapPytorch(**dd), **dd) i.build_overlap(positions) o = i.overlap @@ -53,7 +52,6 @@ def run(numbers: Tensor, positions: Tensor, dd: DD) -> None: @pytest.mark.parametrize("name", ["H2"]) @pytest.mark.parametrize("dtype", [torch.float, torch.double]) def test_single(dtype: torch.dtype, name: str): - """Overlap matrix for monoatomic molecule should be unity.""" dd: DD = {"dtype": dtype, "device": DEVICE} sample = samples[name] @@ -67,7 +65,6 @@ def test_single(dtype: torch.dtype, name: str): @pytest.mark.parametrize("name2", ["LiH"]) @pytest.mark.parametrize("dtype", [torch.float, torch.double]) def test_batch(dtype: torch.dtype, name1: str, name2: str): - """Overlap matrix for monoatomic molecule should be unity.""" dd: DD = {"dtype": dtype, "device": DEVICE} sample1, sample2 = samples[name1], samples[name2] diff --git a/test/test_integrals/test_types.py b/test/test_integrals/test_types.py index ff8f13796..fb7dbff4a 100644 --- a/test/test_integrals/test_types.py +++ b/test/test_integrals/test_types.py @@ -23,10 +23,19 @@ import pytest import torch -from dxtb import GFN1_XTB, GFN2_XTB, IndexHelper -from dxtb.integrals import types as inttypes +from dxtb import GFN1_XTB, IndexHelper, labels +from dxtb.integrals.factories import new_dipint, new_hcore, new_quadint 
numbers = torch.tensor([14, 1, 1, 1, 1]) +positions = torch.tensor( + [ + [+0.00000000000000, -0.00000000000000, +0.00000000000000], + [+1.61768389755830, +1.61768389755830, -1.61768389755830], + [-1.61768389755830, -1.61768389755830, -1.61768389755830], + [+1.61768389755830, -1.61768389755830, +1.61768389755830], + [-1.61768389755830, +1.61768389755830, +1.61768389755830], + ] +) def test_fail() -> None: @@ -37,4 +46,25 @@ def test_fail() -> None: assert par1.meta is not None par1.meta.name = "fail" - inttypes.HCore(numbers, par1, ihelp) + new_hcore(numbers, par1, ihelp) + + +def test_dipole_fail() -> None: + i = new_dipint(labels.INTDRIVER_LIBCINT) + + with pytest.raises(RuntimeError): + fake_ovlp = torch.eye(3, dtype=torch.float64) + i.shift_r0_rj(fake_ovlp, positions) + + +def test_quadrupole_fail() -> None: + i = new_quadint(labels.INTDRIVER_LIBCINT) + + with pytest.raises(RuntimeError): + fake_ovlp = torch.eye(3, dtype=torch.float64) + fake_r0 = torch.zeros(3, dtype=torch.float64) + i.shift_r0r0_rjrj(fake_r0, fake_ovlp, positions) + + with pytest.raises(RuntimeError): + i._matrix = torch.eye(3, dtype=torch.float64) + i.traceless() diff --git a/test/test_integrals/test_wrappers.py b/test/test_integrals/test_wrappers.py index 547e77410..4bb734dbf 100644 --- a/test/test_integrals/test_wrappers.py +++ b/test/test_integrals/test_wrappers.py @@ -39,12 +39,12 @@ def test_fail() -> None: - with pytest.raises(TypeError): + with pytest.raises(ValueError): par1 = GFN1_XTB.model_copy(deep=True) par1.meta = None wrappers.hcore(numbers, positions, par1) - with pytest.raises(TypeError): + with pytest.raises(ValueError): par1 = GFN1_XTB.model_copy(deep=True) assert par1.meta is not None @@ -70,7 +70,7 @@ def test_h0_gfn1(par: Param) -> None: h0 = wrappers.hcore(numbers, positions, par) assert h0.shape == (17, 17) - h0 = wrappers.hcore(numbers, positions, par, cn=torch.zeros(numbers.shape)) + h0 = wrappers.hcore(numbers, positions, par, cn=None) assert h0.shape == (17, 17) diff --git a/test/test_interaction/test_grad.py b/test/test_interaction/test_grad.py index 0d7efa72e..38403a9e4 100644 --- a/test/test_interaction/test_grad.py +++ b/test/test_interaction/test_grad.py @@ -56,14 +56,14 @@ def gradchecker( ilist = InteractionList(new_es2(numbers, par, **dd), new_es3(numbers, par, **dd)) # variables to be differentiated - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) - def func(pos: Tensor) -> Tensor: - icaches = ilist.get_cache(numbers=numbers, positions=pos, ihelp=ihelp) - charges = get_guess(numbers, positions, chrg, ihelp) + def func(p: Tensor) -> Tensor: + icaches = ilist.get_cache(numbers=numbers, positions=p, ihelp=ihelp) + charges = get_guess(numbers, p, chrg, ihelp) return ilist.get_energy(charges, icaches, ihelp) - return func, positions + return func, pos @pytest.mark.grad @@ -116,14 +116,14 @@ def gradchecker_batch( ilist = InteractionList(new_es2(numbers, par, **dd), new_es3(numbers, par, **dd)) # variables to be differentiated - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) - def func(pos: Tensor) -> Tensor: - icaches = ilist.get_cache(numbers=numbers, positions=pos, ihelp=ihelp) - charges = get_guess(numbers, positions, chrg, ihelp) + def func(p: Tensor) -> Tensor: + icaches = ilist.get_cache(numbers=numbers, positions=p, ihelp=ihelp) + charges = get_guess(numbers, p, chrg, ihelp) return ilist.get_energy(charges, icaches, ihelp) - return func, positions + return func, pos @pytest.mark.grad diff --git 
a/test/test_libcint/test_gradcheck.py b/test/test_libcint/test_gradcheck.py index 0c6a1e914..6b3d44937 100644 --- a/test/test_libcint/test_gradcheck.py +++ b/test/test_libcint/test_gradcheck.py @@ -56,7 +56,7 @@ def gradchecker( bas = Basis(numbers, par, ihelp, **dd) # variables to be differentiated - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) def func(pos: Tensor) -> Tensor: atombases = bas.create_libcint(pos) @@ -65,7 +65,7 @@ def func(pos: Tensor) -> Tensor: wrapper = libcint.LibcintWrapper(atombases, ihelp, spherical=False) return libcint.int1e(intstr, wrapper) - return func, positions + return func, pos @pytest.mark.grad @@ -124,13 +124,13 @@ def gradchecker_batch( overlap = OverlapLibcint(**dd) # variables to be differentiated - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) - def func(pos: Tensor) -> Tensor: - driver.setup(pos, mask=mask) + def func(p: Tensor) -> Tensor: + driver.setup(p, mask=mask) return overlap.build(driver) - return func, positions + return func, pos @pytest.mark.grad diff --git a/test/test_libcint/test_overlap_grad.py b/test/test_libcint/test_overlap_grad.py index 94d9bc12c..04abc0d27 100644 --- a/test/test_libcint/test_overlap_grad.py +++ b/test/test_libcint/test_overlap_grad.py @@ -100,8 +100,8 @@ def autograd(name: str, dd: DD, tol: float) -> None: bas = Basis(numbers, par, ihelp, **dd) # variable to be differentiated - positions.requires_grad_(True) - atombases = bas.create_libcint(positions) + pos = positions.clone().requires_grad_(True) + atombases = bas.create_libcint(pos) assert is_basis_list(atombases) wrapper = libcint.LibcintWrapper(atombases, ihelp) @@ -109,9 +109,7 @@ def autograd(name: str, dd: DD, tol: float) -> None: norm = torch.pow(s.diagonal(dim1=-1, dim2=-2), -0.5) s = einsum("...ij,...i,...j->...ij", s, norm, norm) - (g,) = torch.autograd.grad(s.sum(), positions) - positions.detach_() - + (g,) = torch.autograd.grad(s.sum(), pos) assert pytest.approx(ref.cpu(), abs=tol) == g.cpu() diff --git a/test/test_multipole/todo_test_dipole_grad.py b/test/test_multipole/todo_test_dipole_grad.py index 29fceda61..ea7fb80af 100644 --- a/test/test_multipole/todo_test_dipole_grad.py +++ b/test/test_multipole/todo_test_dipole_grad.py @@ -91,12 +91,12 @@ def test_grad(dtype: torch.dtype, name: str): numbers = sample["numbers"].to(DEVICE) positions = sample["positions"].to(**dd) positions[0] = torch.tensor([0, 0, 0], **dd) - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) ihelp = IndexHelper.from_numbers(numbers, par) bas = Basis(numbers, par, ihelp, **dd) - atombases = bas.create_libcint(positions) + atombases = bas.create_libcint(pos) assert is_basis_list(atombases) INTSTR = "r0" @@ -110,7 +110,7 @@ def test_grad(dtype: torch.dtype, name: str): # assert False print(igrad.shape) - numgrad = num_grad(numbers, ihelp, positions, INTSTR) + numgrad = num_grad(numbers, ihelp, pos, INTSTR) print("numgrad\n", numgrad) print("") print("") diff --git a/test/test_overlap/test_grad_pos.py b/test/test_overlap/test_grad_pos.py index 83c97e852..34ecd079b 100644 --- a/test/test_overlap/test_grad_pos.py +++ b/test/test_overlap/test_grad_pos.py @@ -31,7 +31,7 @@ from dxtb._src.integral.driver.pytorch import OverlapPytorch as Overlap from dxtb._src.typing import DD, Callable, Tensor -from ..conftest import DEVICE +from ..conftest import DEVICE, NONDET_TOL from .samples import samples slist = ["LiH", "H2O"] @@ -55,13 +55,13 @@ def gradchecker( overlap = Overlap(uplo="n", **dd) # 
variables to be differentiated - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) - def func(pos: Tensor) -> Tensor: - driver.setup(pos) + def func(p: Tensor) -> Tensor: + driver.setup(p) return overlap.build(driver) - return func, positions + return func, pos @pytest.mark.grad @@ -73,7 +73,7 @@ def test_grad(dtype: torch.dtype, name: str) -> None: gradient from `torch.autograd.gradcheck`. """ func, diffvars = gradchecker(dtype, name) - assert dgradcheck(func, diffvars, atol=tol) + assert dgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL) @pytest.mark.grad @@ -85,7 +85,7 @@ def test_grad_large(dtype: torch.dtype, name: str) -> None: gradient from `torch.autograd.gradcheck`. """ func, diffvars = gradchecker(dtype, name) - assert dgradcheck(func, diffvars, atol=tol) + assert dgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL) @pytest.mark.grad @@ -97,7 +97,7 @@ def test_gradgrad(dtype: torch.dtype, name: str) -> None: gradient from `torch.autograd.gradgradcheck`. """ func, diffvars = gradchecker(dtype, name) - assert dgradgradcheck(func, diffvars, atol=tol) + assert dgradgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL) @pytest.mark.grad @@ -109,7 +109,7 @@ def test_gradgrad_large(dtype: torch.dtype, name: str) -> None: gradient from `torch.autograd.gradgradcheck`. """ func, diffvars = gradchecker(dtype, name) - assert dgradgradcheck(func, diffvars, atol=tol) + assert dgradgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL) def gradchecker_batch( @@ -141,13 +141,13 @@ def gradchecker_batch( overlap = Overlap(uplo="n", **dd) # variables to be differentiated - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) - def func(pos: Tensor) -> Tensor: - driver.setup(pos, mask=mask) + def func(p: Tensor) -> Tensor: + driver.setup(p, mask=mask) return overlap.build(driver) - return func, positions + return func, pos @pytest.mark.grad @@ -161,7 +161,7 @@ def test_grad_batch(dtype: torch.dtype, name1: str, name2: str) -> None: gradient from `torch.autograd.gradcheck`. """ func, diffvars = gradchecker_batch(dtype, name1, name2) - assert dgradcheck(func, diffvars, atol=tol) + assert dgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL) @pytest.mark.grad @@ -175,7 +175,7 @@ def test_grad_batch_large(dtype: torch.dtype, name1: str, name2: str) -> None: gradient from `torch.autograd.gradcheck`. """ func, diffvars = gradchecker_batch(dtype, name1, name2) - assert dgradcheck(func, diffvars, atol=tol) + assert dgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL) @pytest.mark.grad @@ -189,7 +189,7 @@ def test_gradgrad_batch(dtype: torch.dtype, name1: str, name2: str) -> None: gradient from `torch.autograd.gradgradcheck`. """ func, diffvars = gradchecker_batch(dtype, name1, name2) - assert dgradgradcheck(func, diffvars, atol=tol) + assert dgradgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL) @pytest.mark.grad @@ -203,4 +203,4 @@ def test_gradgrad_batch_large(dtype: torch.dtype, name1: str, name2: str) -> Non gradient from `torch.autograd.gradgradcheck`. 
""" func, diffvars = gradchecker_batch(dtype, name1, name2) - assert dgradgradcheck(func, diffvars, atol=tol) + assert dgradgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL) diff --git a/test/test_overlap/test_gradient_grad_pos.py b/test/test_overlap/test_gradient_grad_pos.py index 4ae48568b..dd49502ce 100644 --- a/test/test_overlap/test_gradient_grad_pos.py +++ b/test/test_overlap/test_gradient_grad_pos.py @@ -53,12 +53,12 @@ def gradchecker( bas = Basis(torch.unique(numbers), par, ihelp, **dd) # variables to be differentiated - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) - def func(pos: Tensor) -> Tensor: - return overlap_gradient(pos, bas, ihelp, uplo=uplo) + def func(p: Tensor) -> Tensor: + return overlap_gradient(p, bas, ihelp, uplo=uplo) - return func, positions + return func, pos @pytest.mark.grad diff --git a/test/test_param/test_param.py b/test/test_param/test_param.py index 44bc2ea75..228497b1d 100644 --- a/test/test_param/test_param.py +++ b/test/test_param/test_param.py @@ -138,7 +138,7 @@ def test_param_calculator(dtype: torch.dtype) -> None: h = calc.integrals.hcore assert h is not None - occ = calc.ihelp.reduce_shell_to_atom(h.integral.refocc) + occ = calc.ihelp.reduce_shell_to_atom(h.refocc) assert pytest.approx(ref.cpu()) == occ.cpu() diff --git a/test/test_properties/test_forces.py b/test/test_properties/test_forces.py index 67e249091..33a3aa968 100644 --- a/test/test_properties/test_forces.py +++ b/test/test_properties/test_forces.py @@ -57,14 +57,14 @@ def skip_test_autograd(dtype: torch.dtype, name: str) -> None: charge = torch.tensor(0.0, **dd) # required for autodiff of energy w.r.t. positions - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) calc = Calculator(numbers, par, opts=opts, **dd) - def f(pos: Tensor) -> Tensor: - return calc.forces(pos, charge) + def f(p: Tensor) -> Tensor: + return calc.forces(p, charge) - assert dgradcheck(f, positions) + assert dgradcheck(f, pos) def single( diff --git a/test/test_properties/test_hessian.py b/test/test_properties/test_hessian.py index 819072826..6a12cec58 100644 --- a/test/test_properties/test_hessian.py +++ b/test/test_properties/test_hessian.py @@ -59,14 +59,14 @@ def skip_test_autograd(dtype: torch.dtype, name: str) -> None: charge = torch.tensor(0.0, **dd) # required for autodiff of energy w.r.t. positions - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) calc = Calculator(numbers, par, opts=opts, **dd) - def f(pos: Tensor) -> Tensor: - return calc.hessian(pos, charge) + def f(p: Tensor) -> Tensor: + return calc.hessian(p, charge) - assert dgradcheck(f, positions) + assert dgradcheck(f, pos) def single( diff --git a/test/test_properties/test_vibration.py b/test/test_properties/test_vibration.py index 419c904b2..da8f7a67f 100644 --- a/test/test_properties/test_vibration.py +++ b/test/test_properties/test_vibration.py @@ -60,15 +60,15 @@ def skip_test_autograd(dtype: torch.dtype, name: str) -> None: charge = torch.tensor(0.0, **dd) # required for autodiff of energy w.r.t. 
efield and dipole - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) calc = Calculator(numbers, par, opts=opts, **dd) - def f(pos: Tensor) -> tuple[Tensor, Tensor]: - f, m = calc.vibration(pos, charge) + def f(p: Tensor) -> tuple[Tensor, Tensor]: + f, m = calc.vibration(p, charge) return f, m - assert dgradcheck(f, positions) + assert dgradcheck(f, pos) def single( diff --git a/test/test_properties/todo_test_quadrupole.py b/test/test_properties/todo_test_quadrupole.py index a36e616d5..c1f83d78d 100644 --- a/test/test_properties/todo_test_quadrupole.py +++ b/test/test_properties/todo_test_quadrupole.py @@ -176,7 +176,7 @@ def batched( # required for autodiff of energy w.r.t. efield and quadrupole if use_functorch is True: field_vector.requires_grad_(True) - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) field_grad = torch.zeros((3, 3), **dd, requires_grad=True) else: field_grad = None @@ -186,9 +186,7 @@ def batched( efield_grad = new_efield_grad(field_grad) calc = Calculator(numbers, par, interaction=[efield], opts=opts, **dd) - quadrupole = calc.quadrupole( - numbers, positions, charge, use_functorch=use_functorch - ) + quadrupole = calc.quadrupole(numbers, pos, charge, use_functorch=use_functorch) quadrupole.detach_() assert pytest.approx(ref, abs=atol, rel=rtol) == quadrupole @@ -323,13 +321,13 @@ def test_batch_settings( # required for autodiff of energy w.r.t. efield and quadrupole field_vector.requires_grad_(True) - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) efield = new_efield(field_vector) options = dict(opts, **{"scp_mode": scp_mode, "mixer": mixer}) calc = Calculator(numbers, par, interaction=[efield], opts=options, **dd) - quadrupole = calc.quadrupole(numbers, positions, charge) + quadrupole = calc.quadrupole(numbers, pos, charge) quadrupole.detach_() assert pytest.approx(ref, abs=1e-4) == quadrupole @@ -370,7 +368,7 @@ def test_batch_unconverged(dtype: torch.dtype, name1: str, name2: str) -> None: # required for autodiff of energy w.r.t. efield and quadrupole field_vector.requires_grad_(True) - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) # with 5 iterations, both do not converge, but pass the test options = dict(opts, **{"maxiter": 5, "mixer": "simple"}) @@ -378,7 +376,7 @@ def test_batch_unconverged(dtype: torch.dtype, name1: str, name2: str) -> None: efield = new_efield(field_vector) calc = Calculator(numbers, par, interaction=[efield], opts=options, **dd) - quadrupole = calc.quadrupole(numbers, positions, charge) + quadrupole = calc.quadrupole(numbers, pos, charge) quadrupole.detach_() assert pytest.approx(ref, abs=1e-2, rel=1e-3) == quadrupole @@ -419,7 +417,7 @@ def test_batch_unconverged(dtype: torch.dtype, name1: str, name2: str) -> None: # required for autodiff of energy w.r.t. 
diff --git a/test/test_repulsion/test_grad_pos.py b/test/test_repulsion/test_grad_pos.py
index abf77b757..c9f25306f 100644
--- a/test/test_repulsion/test_grad_pos.py
+++ b/test/test_repulsion/test_grad_pos.py
@@ -64,17 +64,17 @@ def test_backward_vs_tblite(dtype: torch.dtype, name: str) -> None:
     cache = rep.get_cache(numbers, ihelp)

     # automatic gradient
-    positions.requires_grad_(True)
-    energy = torch.sum(rep.get_energy(positions, cache), dim=-1)
+    pos = positions.clone().requires_grad_(True)
+    energy = torch.sum(rep.get_energy(pos, cache), dim=-1)
     energy.backward()

-    assert positions.grad is not None
-    grad_backward = positions.grad.clone()
+    assert pos.grad is not None
+    grad_backward = pos.grad.clone()

     # also zero out gradients when using `.backward()`
     grad_backward.detach_()
-    positions.detach_()
-    positions.grad.data.zero_()
+    pos.detach_()
+    pos.grad.data.zero_()

     assert pytest.approx(ref.cpu(), abs=tol) == grad_backward.cpu()
@@ -114,17 +114,17 @@ def test_backward_batch_vs_tblite(dtype: torch.dtype, name1: str, name2: str) ->
     cache = rep.get_cache(numbers, ihelp)

     # automatic gradient
-    positions.requires_grad_(True)
-    energy = torch.sum(rep.get_energy(positions, cache))
+    pos = positions.clone().requires_grad_(True)
+    energy = torch.sum(rep.get_energy(pos, cache))
     energy.backward()

-    assert positions.grad is not None
-    grad_backward = positions.grad.clone()
+    assert pos.grad is not None
+    grad_backward = pos.grad.clone()

     # also zero out gradients when using `.backward()`
     grad_backward.detach_()
-    positions.detach_()
-    positions.grad.data.zero_()
+    pos.detach_()
+    pos.grad.data.zero_()

     assert pytest.approx(ref.cpu(), abs=tol) == grad_backward.cpu()
@@ -152,17 +152,17 @@ def test_grad_pos_backward_vs_analytical(dtype: torch.dtype, name: str) -> None:
     )

     # automatic gradient
-    positions.requires_grad_(True)
-    energy = torch.sum(rep.get_energy(positions, cache), dim=-1)
+    pos = positions.clone().requires_grad_(True)
+    energy = torch.sum(rep.get_energy(pos, cache), dim=-1)
     energy.backward()

-    assert positions.grad is not None
-    grad_backward = positions.grad.clone()
+    assert pos.grad is not None
+    grad_backward = pos.grad.clone()

     # also zero out gradients when using `.backward()`
     grad_backward.detach_()
-    positions.detach_()
-    positions.grad.data.zero_()
+    pos.detach_()
+    pos.grad.data.zero_()

     assert pytest.approx(grad_analytical.cpu(), abs=tol) == grad_backward.cpu()
@@ -240,12 +240,12 @@ def gradchecker(
     cache = rep.get_cache(numbers, ihelp)

     # variables to be differentiated
-    positions.requires_grad_(True)
+    pos = positions.clone().requires_grad_(True)

-    def func(pos: Tensor) -> Tensor:
-        return rep.get_energy(pos, cache)
+    def func(p: Tensor) -> Tensor:
+        return rep.get_energy(p, cache)

-    return func, positions
+    return func, pos


 @pytest.mark.grad
@@ -299,12 +299,12 @@ def gradchecker_batch(
     cache = rep.get_cache(numbers, ihelp)

     # variables to be differentiated
-    positions.requires_grad_(True)
+    pos = positions.clone().requires_grad_(True)

-    def func(pos: Tensor) -> Tensor:
-        return rep.get_energy(pos, cache)
+    def func(p: Tensor) -> Tensor:
+        return rep.get_energy(p, cache)

-    return func, positions
+    return func, pos


 @pytest.mark.grad
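The repulsion tests extract gradients through `.backward()` rather than `torch.autograd.grad`, then clean up the leaf tensor so it can be reused. A self-contained sketch of that flow, with `toy_energy` standing in for `rep.get_energy(pos, cache)`:

    import torch

    def toy_energy(p: torch.Tensor) -> torch.Tensor:
        # stand-in for rep.get_energy(pos, cache)
        return (p.norm(dim=-1) ** 2).sum()

    positions = torch.rand(4, 3)
    pos = positions.clone().requires_grad_(True)

    energy = toy_energy(pos)
    energy.backward()

    assert pos.grad is not None
    grad_backward = pos.grad.clone()  # copy before the buffer is cleared

    # detach the leaf and zero its .grad so later uses start from a clean state
    pos.detach_()
    pos.grad.data.zero_()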
diff --git a/test/test_repulsion/test_hess.py b/test/test_repulsion/test_hess.py
index 6f6aa2f77..b4cfcaaf1 100644
--- a/test/test_repulsion/test_hess.py
+++ b/test/test_repulsion/test_hess.py
@@ -54,7 +54,7 @@ def test_single(dtype: torch.dtype, name: str) -> None:
     )

     # variable to be differentiated
-    positions.requires_grad_(True)
+    pos = positions.clone().requires_grad_(True)

     rep = new_repulsion(numbers, par, **dd)
     assert rep is not None
@@ -65,10 +65,10 @@ def test_single(dtype: torch.dtype, name: str) -> None:
     def energy(pos: Tensor) -> Tensor:
         return rep.get_energy(pos, cache).sum()

-    hess = jacrev(jacrev(energy))(positions)
+    hess = jacrev(jacrev(energy))(pos)
     assert isinstance(hess, Tensor)

-    positions.detach_()
+    pos.detach_()

     hess = hess.detach().reshape_as(ref)
     assert ref.shape == hess.shape
@@ -124,13 +124,12 @@ def skip_test_batch(dtype: torch.dtype, name1: str, name2) -> None:
     ihelp = IndexHelper.from_numbers(numbers, par)
     cache = rep.get_cache(numbers, ihelp)

-    positions.requires_grad_(True)
+    pos = positions.clone().requires_grad_(True)

-    def energy(pos: Tensor) -> Tensor:
-        return rep.get_energy(pos, cache).sum()
+    def energy(p: Tensor) -> Tensor:
+        return rep.get_energy(p, cache).sum()

-    hess = jacrev(jacrev(energy))(positions)
+    hess = jacrev(jacrev(energy))(pos)
     assert isinstance(hess, Tensor)

-    positions.detach_()
     assert pytest.approx(ref.cpu(), abs=tol, rel=tol) == hess.detach().cpu()
diff --git a/test/test_scf/skip_test_grad_pos.py b/test/test_scf/skip_test_grad_pos.py
index 3281330ed..7efc0c95c 100644
--- a/test/test_scf/skip_test_grad_pos.py
+++ b/test/test_scf/skip_test_grad_pos.py
@@ -61,13 +61,13 @@ def gradchecker(
     calc = Calculator(numbers, par, opts=options, **dd)

     # variables to be differentiated
-    positions.requires_grad_(True)
+    pos = positions.clone().requires_grad_(True)

-    def func(pos: Tensor) -> Tensor:
-        result = calc.singlepoint(pos, charges)
+    def func(p: Tensor) -> Tensor:
+        result = calc.singlepoint(p, charges)
         return result.total

-    return func, positions
+    return func, pos


 @pytest.mark.grad
diff --git a/test/test_scf/test_charged.py b/test/test_scf/test_charged.py
index abac618e9..305309287 100644
--- a/test/test_scf/test_charged.py
+++ b/test/test_scf/test_charged.py
@@ -76,7 +76,7 @@ def test_grad(dtype: torch.dtype, name: str):
     sample = samples[name]
     numbers = sample["numbers"].to(DEVICE)
     positions = samples[name]["positions"].to(**dd).detach()
-    positions.requires_grad_(True)
+    pos = positions.clone().requires_grad_(True)
     chrg = sample["charge"].to(**dd)

     # Values obtain with tblite 0.2.1 disabling repulsion and dispersion
@@ -92,8 +92,8 @@ def test_grad(dtype: torch.dtype, name: str):
         },
     )
     calc = Calculator(numbers, par, opts=options, **dd)
-    result = calc.singlepoint(positions, chrg)
+    result = calc.singlepoint(pos, chrg)
     energy = result.scf.sum(-1)

-    (gradient,) = torch.autograd.grad(energy, positions)
+    (gradient,) = torch.autograd.grad(energy, pos)

     assert pytest.approx(gradient.cpu(), abs=tol, rel=1e-5) == ref.cpu()
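Where only the gradient value is needed, the SCF tests call `torch.autograd.grad` directly, which returns the gradient without populating `.grad` buffers on the inputs. A minimal sketch with a toy energy standing in for `result.scf.sum(-1)`:

    import torch

    pos = torch.rand(3, 3).requires_grad_(True)
    energy = (pos**2).sum()  # stand-in for the SCF energy

    # returns a tuple with one gradient per input tensor
    (gradient,) = torch.autograd.grad(energy, pos)
    assert gradient.shape == pos.shape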
samples[name]["numbers"].to(DEVICE) positions = samples[name]["positions"].to(**dd) - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) charges = torch.tensor(0.0, **dd) # Values obtained with tblite 0.2.1 disabling repulsion and dispersion @@ -96,17 +96,17 @@ def run_grad_backwards( ) calc = Calculator(numbers, par, opts=options, **dd) - result = calc.singlepoint(positions, charges) + result = calc.singlepoint(pos, charges) energy = result.scf.sum(-1) energy.backward() - assert positions.grad is not None + assert pos.grad is not None - gradient = positions.grad.clone() + gradient = pos.grad.clone() # also zero out gradients when using `.backward()` - positions.detach_() - positions.grad.data.zero_() + pos.detach_() + pos.grad.data.zero_() assert pytest.approx(ref.cpu(), abs=tol, rel=tol) == gradient.cpu() @@ -130,7 +130,7 @@ def run_grad_autograd(name: str, dtype: torch.dtype): numbers = samples[name]["numbers"].to(DEVICE) positions = samples[name]["positions"].to(**dd) - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) charges = torch.tensor(0.0, **dd) # Values obtained with tblite 0.2.1 disabling repulsion and dispersion @@ -144,16 +144,14 @@ def run_grad_autograd(name: str, dtype: torch.dtype): options = dict(opts, **{"f_atol": tol, "x_atol": tol}) calc = Calculator(numbers, par, opts=options, **dd) - result = calc.singlepoint(positions, charges) + result = calc.singlepoint(pos, charges) energy = result.scf.sum(-1) - (gradient,) = torch.autograd.grad(energy, positions) + (gradient,) = torch.autograd.grad(energy, pos) assert pytest.approx(gradient.cpu(), abs=tol, rel=1e-5) == ref.cpu() assert pytest.approx(gradient.cpu(), abs=tol, rel=1e-5) == ref_full.cpu() - positions.detach_() - # FIXME: fails for LYS_xao_dist @pytest.mark.grad @@ -178,21 +176,19 @@ def test_grad_large(name: str, dtype: torch.dtype): assert pytest.approx(ref_full.cpu(), abs=tol, rel=1e-5) == ref.cpu() # variable to be differentiated - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) options = dict(opts, **{"f_atol": tol**2, "x_atol": tol**2}) calc = Calculator(numbers, par, opts=options, **dd) - result = calc.singlepoint(positions, charges) + result = calc.singlepoint(pos, charges) energy = result.scf.sum(-1) - (gradient,) = torch.autograd.grad(energy, positions) + (gradient,) = torch.autograd.grad(energy, pos) assert pytest.approx(gradient.cpu(), abs=tol, rel=1e-5) == ref.cpu() assert pytest.approx(gradient.cpu(), abs=tol, rel=1e-5) == ref_full.cpu() - positions.detach_() - @pytest.mark.grad @pytest.mark.parametrize("name", ["LiH"]) @@ -221,19 +217,19 @@ def run_param_grad_energy(name: str, dtype: torch.dtype = torch.float): numbers = samples[name]["numbers"].to(DEVICE) positions = samples[name]["positions"].to(**dd) - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) charges = torch.tensor(0.0, **dd) options = dict(opts, **{"f_atol": tol**2, "x_atol": tol**2}) calc = Calculator(numbers, par, opts=options, **dd) assert calc.integrals.hcore is not None - h = calc.integrals.hcore.integral + h = calc.integrals.hcore h.selfenergy.requires_grad_(True) h.kcn.requires_grad_(True) h.shpoly.requires_grad_(True) - result = calc.singlepoint(positions, charges) + result = calc.singlepoint(pos, charges) energy = result.scf.sum(-1) pgrad = torch.autograd.grad( @@ -248,8 +244,6 @@ def run_param_grad_energy(name: str, dtype: torch.dtype = torch.float): ref_shpoly = load_from_npz(ref_grad_param, f"{name}_egrad_shpoly", 
diff --git a/test/test_scf/test_guess_grad.py b/test/test_scf/test_guess_grad.py
index 9841f1512..bc8a32440 100644
--- a/test/test_scf/test_guess_grad.py
+++ b/test/test_scf/test_guess_grad.py
@@ -53,12 +53,12 @@ def gradchecker(
     ihelp = IndexHelper.from_numbers(numbers, par)

     # variables to be differentiated
-    positions.requires_grad_(True)
+    pos = positions.clone().requires_grad_(True)

-    def func(pos: Tensor) -> Tensor:
-        return guess.get_guess(numbers, pos, charge, ihelp, name=guess_name)
+    def func(p: Tensor) -> Tensor:
+        return guess.get_guess(numbers, p, charge, ihelp, name=guess_name)

-    return func, positions
+    return func, pos


 @pytest.mark.grad
@@ -111,12 +111,12 @@ def gradchecker_batch(
     ihelp = IndexHelper.from_numbers(numbers, par)

     # variables to be differentiated
-    positions.requires_grad_(True)
+    pos = positions.clone().requires_grad_(True)

-    def func(pos: Tensor) -> Tensor:
-        return guess.get_guess(numbers, pos, charge, ihelp, name=guess_name)
+    def func(p: Tensor) -> Tensor:
+        return guess.get_guess(numbers, p, charge, ihelp, name=guess_name)

-    return func, positions
+    return func, pos


 @pytest.mark.grad
diff --git a/test/test_scf/test_hess.py b/test/test_scf/test_hess.py
index cab7a8e3b..e0ce10061 100644
--- a/test/test_scf/test_hess.py
+++ b/test/test_scf/test_hess.py
@@ -63,20 +63,22 @@ def test_single(dtype: torch.dtype, name: str) -> None:
     calc = Calculator(numbers, par, opts=opts, **dd)

+    pos = positions.clone()
+
     # numerical hessian
-    numref = _numhess(calc, numbers, positions, charge)
+    numref = _numhess(calc, numbers, pos, charge)

     # variable to be differentiated
-    positions.requires_grad_(True)
+    pos.requires_grad_(True)

     def energy(pos: Tensor) -> Tensor:
         result = calc.singlepoint(pos, charge)
         return result.scf.sum()

-    hess = jacrev(jacrev(energy))(positions)
+    hess = jacrev(jacrev(energy))(pos)
     assert isinstance(hess, Tensor)

-    positions.detach_()
+    pos.detach_()

     hess = hess.detach().reshape_as(ref)
     numref = numref.reshape_as(ref)
@@ -95,10 +97,10 @@ def _numhess(
         **{"device": positions.device, "dtype": positions.dtype},
     )

-    def _gradfcn(positions: Tensor, charge: Tensor) -> Tensor:
-        positions.requires_grad_(True)
-        result = -calc.forces_analytical(positions, charge)
-        positions.detach_()
+    def _gradfcn(pos: Tensor, charge: Tensor) -> Tensor:
+        pos.requires_grad_(True)
+        result = -calc.forces_analytical(pos, charge)
+        pos.detach_()
         return result.detach()

     step = 1.0e-4
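The Hessian tests compose `jacrev` with itself: the outer call differentiates the gradient produced by the inner call. A minimal sketch of the nested call on a toy energy (imported here from `torch.func`; the test suite pulls `jacrev` from its own compatibility layer):

    import torch
    from torch.func import jacrev  # requires torch >= 2.0

    def energy(pos: torch.Tensor) -> torch.Tensor:
        return (pos**3).sum()

    pos = torch.rand(2, 3)
    # Jacobian of the gradient, i.e. the Hessian: shape (2, 3, 2, 3)
    hess = jacrev(jacrev(energy))(pos)
    assert hess.shape == (*pos.shape, *pos.shape)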
diff --git a/test/test_singlepoint/test_grad_pos_withfield.py b/test/test_singlepoint/test_grad_pos_withfield.py
index 3fea66087..732af0b33 100644
--- a/test/test_singlepoint/test_grad_pos_withfield.py
+++ b/test/test_singlepoint/test_grad_pos_withfield.py
@@ -68,14 +68,14 @@ def gradchecker(dtype: torch.dtype, name: str) -> tuple[
     calc = Calculator(numbers, par, interaction=[efield], opts=opts, **dd)

     # variables to be differentiated
-    positions.requires_grad_(True)
+    pos = positions.clone().requires_grad_(True)

-    def func(pos: Tensor) -> Tensor:
-        result = calc.singlepoint(pos, charge)
+    def func(p: Tensor) -> Tensor:
+        result = calc.singlepoint(p, charge)
         energy = result.total.sum(-1)
         return energy

-    return func, positions
+    return func, pos


 @pytest.mark.grad
@@ -130,14 +130,14 @@ def gradchecker_batch(dtype: torch.dtype, name1: str, name2: str) -> tuple[
     calc = Calculator(numbers, par, interaction=[efield], opts=opts, **dd)

     # variables to be differentiated
-    positions.requires_grad_(True)
+    pos = positions.clone().requires_grad_(True)

-    def func(pos: Tensor) -> Tensor:
-        result = calc.singlepoint(pos, charge)
+    def func(p: Tensor) -> Tensor:
+        result = calc.singlepoint(p, charge)
         energy = result.total.sum(-1)
         return energy

-    return func, positions
+    return func, pos


 @pytest.mark.grad
diff --git a/test/test_singlepoint/test_hess.py b/test/test_singlepoint/test_hess.py
index 8b213be7c..30d9a83d4 100644
--- a/test/test_singlepoint/test_hess.py
+++ b/test/test_singlepoint/test_hess.py
@@ -72,19 +72,17 @@ def test_single(dtype: torch.dtype, name: str) -> None:
     )

     # variable to be differentiated
-    positions.requires_grad_(True)
+    pos = positions.clone().requires_grad_(True)

     options = dict(opts, **{"exclude": ["scf"]})
     calc = Calculator(numbers, par, opts=options, **dd)

-    def energy(pos: Tensor) -> Tensor:
-        result = calc.singlepoint(pos, charge)
+    def energy(p: Tensor) -> Tensor:
+        result = calc.singlepoint(p, charge)
         return result.total.sum()

-    hess = jacrev(jacrev(energy))(positions)
+    hess = jacrev(jacrev(energy))(pos)
     assert isinstance(hess, Tensor)

-    positions.detach_()
     hess = hess.detach().reshape_as(ref)
-
     assert pytest.approx(ref.cpu(), abs=tol, rel=tol) == hess.cpu()
diff --git a/test/test_solvation/test_born.py b/test/test_solvation/test_born.py
index 76a4ad4d3..385fe0b19 100644
--- a/test/test_solvation/test_born.py
+++ b/test/test_solvation/test_born.py
@@ -180,12 +180,12 @@ def test_psi_grad(name: str):
     rvdw = VDW_D3.to(**dd)[numbers]

     # variable to be differentiated
-    positions.requires_grad_(True)
+    pos = positions.clone().requires_grad_(True)

-    def func(positions: Tensor):
-        return born.compute_psi(numbers, positions, rvdw)
+    def func(p: Tensor):
+        return born.compute_psi(numbers, p, rvdw)

-    assert dgradcheck(func, positions)
+    assert dgradcheck(func, pos)


 @pytest.mark.grad
@@ -200,9 +200,9 @@ def test_radii_grad(name: str):
     positions = sample["positions"].to(**dd)

     # variable to be differentiated
-    positions.requires_grad_(True)
+    pos = positions.clone().requires_grad_(True)

-    def func(positions: Tensor):
-        return born.get_born_radii(numbers, positions)
+    def func(p: Tensor):
+        return born.get_born_radii(numbers, p)

-    assert dgradcheck(func, positions)
+    assert dgradcheck(func, pos)
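`dgradcheck` appears to be a thin wrapper around PyTorch's gradient checker; the essential mechanics look roughly like this sketch (note that `gradcheck` wants double-precision inputs, which is presumably what the "d" prefix refers to):

    import torch
    from torch.autograd import gradcheck

    def func(p: torch.Tensor) -> torch.Tensor:
        return (p * p).sum(-1)  # stand-in for born.compute_psi(...)

    pos = torch.rand(4, 3, dtype=torch.double).requires_grad_(True)

    # compares the analytical Jacobian against finite differences
    assert gradcheck(func, (pos,), atol=1e-6)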
diff --git a/test/test_solvation/test_grad.py b/test/test_solvation/test_grad.py
index 894f64ed9..54060ef73 100644
--- a/test/test_solvation/test_grad.py
+++ b/test/test_solvation/test_grad.py
@@ -51,7 +51,7 @@ def test_gb_scf_grad(dtype: torch.dtype, name: str, dielectric_constant=78.9):
     sample = samples[name]
     numbers = sample["numbers"].to(DEVICE)
     positions = sample["positions"].to(**dd)
-    positions.requires_grad_(True)
+    pos = positions.clone().requires_grad_(True)
     ref = sample["gradient"]
     charges = torch.tensor(0.0).type(dtype)

@@ -60,16 +60,16 @@ def test_gb_scf_grad(dtype: torch.dtype, name: str, dielectric_constant=78.9):
     calc = Calculator(numbers, par, interaction=[gb], opts=opts, **dd)

-    results = calc.singlepoint(positions, charges)
+    results = calc.singlepoint(pos, charges)
     energy = results.scf.sum(-1)

     # autograd
     energy.backward()
-    assert positions.grad is not None
-    autograd = positions.grad.clone()
+    assert pos.grad is not None
+    autograd = pos.grad.clone()

     # also zero out gradients when using `.backward()`
-    positions.detach_()
-    positions.grad.data.zero_()
+    pos.detach_()
+    pos.grad.data.zero_()

     assert pytest.approx(ref.cpu(), abs=tol) == autograd.cpu()
diff --git a/tox.ini b/tox.ini
index 16d50c4e1..d7d82f032 100644
--- a/tox.ini
+++ b/tox.ini
@@ -18,11 +18,11 @@ min_version = 4.0
 isolated_build = True
 envlist =
-    py38-torch{1110,1121,1131,201,212,222},
-    py39-torch{1110,1121,1131,201,212,222},
-    py310-torch{1110,1121,1131,201,212,222},
-    py311-torch{1131,201,212,222}
-    py312-torch{222,231}
+    py38-torch{1110,1121,1131,201,212,222,231,240},
+    py39-torch{1110,1121,1131,201,212,222,231,240},
+    py310-torch{1110,1121,1131,201,212,222,231,240},
+    py311-torch{1131,201,212,222,231,240}
+    py312-torch{222,231,240}

 [testenv]
 setenv =
@@ -43,6 +43,7 @@ deps =
     torch222: torch==2.2.2
     torch230: torch==2.3.0
     torch231: torch==2.3.1
+    torch240: torch==2.4.0
     .[tox]

 commands =
     pytest -vv {posargs: \
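With torch 2.3.1 and 2.4.0 added to the matrix, and assuming tox 4 expands the factor names as listed in the envlist above, a single environment can be exercised locally, e.g.:

    tox -e py310-torch240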