Skip to content

Commit

Permalink
Refactor integral code (#162)
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinfriede authored Aug 26, 2024
1 parent b0f948f commit e999f14
Show file tree
Hide file tree
Showing 144 changed files with 3,551 additions and 2,630 deletions.
11 changes: 5 additions & 6 deletions .github/workflows/ubuntu-pytorch-1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,12 @@ jobs:
exclude:
# Check latest versions here: https://download.pytorch.org/whl/torch/
#
# PyTorch now fully supports Python=<3.11
# see: https://github.com/pytorch/pytorch/issues/86566
# PyTorch issues:
# 3.11: https://github.com/pytorch/pytorch/issues/86566
# 3.12: https://github.com/pytorch/pytorch/issues/110436
# 3.13: https://github.com/pytorch/pytorch/issues/130249
#
# PyTorch does now support Python 3.12 (Linux) for 2.2.0 and newer
# see: https://github.com/pytorch/pytorch/issues/110436
#
# PyTorch<1.13.0 does only support Python=<3.10
# PyTorch<1.13.0 does only support Python<3.11 (Linux)
- python-version: "3.11"
torch-version: "1.11.0"
- python-version: "3.11"
Expand Down
9 changes: 5 additions & 4 deletions .github/workflows/ubuntu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,12 @@ jobs:
exclude:
# Check latest versions here: https://download.pytorch.org/whl/torch/
#
# PyTorch fully supports Python=<3.11
# see: https://github.com/pytorch/pytorch/issues/86566
# PyTorch issues:
# 3.11: https://github.com/pytorch/pytorch/issues/86566
# 3.12: https://github.com/pytorch/pytorch/issues/110436
# 3.13: https://github.com/pytorch/pytorch/issues/130249
#
# PyTorch supports Python 3.12 (Linux) for 2.2.0 and newer
# see: https://github.com/pytorch/pytorch/issues/110436
# PyTorch<2.2.0 does only support Python<3.12 (all platforms)
- python-version: "3.12"
torch-version: "2.0.1"
- python-version: "3.12"
Expand Down
15 changes: 15 additions & 0 deletions docs/source/03_for_developers/errors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,18 @@ the error is actually raised because the mixer in `xitorch` produces a
Apparently, this occurs if the convergence criteria are too strict in single
precision. The solution is to increase the convergence criteria to more than
1e-6.
RuntimeError: clone is not supported by NestedIntSymNode
--------------------------------------------------------
This is a bug in PyTorch 2.3.0 and 2.3.1 (see
`PyTorch #128607 <<https://github.com/pytorch/pytorch/issues/128607>`__).
To avoid this error, manually import `torch._dynamo` in the code. For example:
.. code-block:: python
from tad_mctc._version import __tversion__
if __tversion__ in ((2, 3, 0), (2, 3, 1)):
import torch._dynamo
13 changes: 11 additions & 2 deletions examples/limitation_xitorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,30 @@
from pathlib import Path

import torch
from tad_mctc._version import __tversion__
from tad_mctc.io import read

import dxtb
from dxtb.typing import DD

if __tversion__ in ((2, 3, 0), (2, 3, 1)):
import torch._dynamo


dd: DD = {"device": torch.device("cpu"), "dtype": torch.double}

f = Path(__file__).parent / "molecules" / "lih.xyz"
numbers, positions = read.read_from_path(f, **dd)
charge = read.read_chrg_from_path(f, **dd)

opts = {"verbosity": 3, "scf_mode": "nonpure"}
opts = {"verbosity": 0, "scf_mode": "nonpure"}

######################################################################

calc = dxtb.Calculator(numbers, dxtb.GFN1_XTB, opts=opts, **dd)
pos = positions.clone().requires_grad_(True)
hess = calc.hessian(pos, chrg=charge, use_functorch=True)

try:
calc.hessian(pos, chrg=charge, use_functorch=True)
except RuntimeError as e:
print(f"RuntimeError:\n{str(e)}")
9 changes: 8 additions & 1 deletion examples/profiling/batch-vs-seq-nicotine.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,14 @@
dxtb.timer.cuda_sync = False

# read molecule from file
f = Path(__file__).parent / "molecules" / "nicotine.xyz"
p = Path(__file__).parent

if "molecules" not in [x.name for x in p.iterdir()]:
p = p.parent

f = p / "molecules" / "nicotine.xyz"


numbers, positions = read.read_from_path(f, **dd)
charge = read.read_chrg_from_path(f, **dd)

Expand Down
2 changes: 2 additions & 0 deletions examples/profiling/importing.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,5 @@
print("dxtb", t2 - t1)
print("Param", t3 - t2)
print("scipy", t4 - t3)

del scipy, torch, GFN1_XTB
19 changes: 19 additions & 0 deletions examples/run-all.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/bin/bash

set -e

# search recursively for all python files
for example in $(find . -name "*.py"); do
title="Running $example"

# match the length of the title
line=$(printf '=%.0s' $(seq 1 ${#title}))

echo "$line"
echo "$title"
echo "$line"

python3 "$example"

printf "\n\n"
done
8 changes: 4 additions & 4 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ install_requires =
numpy<2
pydantic>=2.0.0
scipy
tad-dftd3>=0.2.2
tad-dftd4>=0.1.0
tad-libcint>=0.0.1
tad-mctc>=0.1.5
tad-dftd3>=0.3.0
tad-dftd4>=0.2.0
tad-libcint>=0.1.0
tad-mctc>=0.2.0
tad-multicharge
tomli
tomli-w
Expand Down
20 changes: 20 additions & 0 deletions src/dxtb/_src/calculators/config/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,26 @@ class ConfigCacheStore:
potential: bool
"""Whether to store the potential matrix."""

def set(self, key: str, value: bool) -> None:
"""
Set configuration options using keyword arguments.
Parameters
----------
key : str
The configuration key.
value : bool
The configuration value.
Example
-------
config.set("hcore", True)
"""
if not hasattr(self, key):
raise ValueError(f"Unknown configuration key: {key}")

setattr(self, key, value)


class ConfigCache:
"""
Expand Down
7 changes: 4 additions & 3 deletions src/dxtb/_src/calculators/properties/vibration/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,9 @@ def to_unit(self, value: Literal["freqs", "modes"], unit: str) -> Tensor:
if value == "freqs":
return self._convert(self.freqs, unit, self.converter_freqs)

# TODO: Implement conversion for normal modes (required?)
# if value == "modes":
# return self._convert(self.modes, unit, self.converter_modes)
# return self._convert(self.modes, unit, self.converter_modes)

raise ValueError(f"Unsupported value for conversion: {value}")

Expand Down Expand Up @@ -143,7 +144,7 @@ def _get_rotational_modes(mass: Tensor, positions: Tensor):
_, paxes = storch.eighb(im)

# make z-axis rotation vector with smallest moment of inertia
# w = torch.flip(w, [-1])
# _ = torch.flip(_, [-1])
paxes = torch.flip(paxes, [-1])
ex, ey, ez = paxes.mT

Expand Down Expand Up @@ -288,7 +289,7 @@ def vib_analysis(
# by 2π (ω = √λ / 2π) in order to immediately get frequencies in
# Hartree: E = hbar * ω with hbar = 1 in atomic units. Dividing by 2π
# effectively converts from angular frequency (ω) to the frequency in
# cycles per second (ν, Hz), which would requires the following
# cycles per second (ν, Hz), which would require the following
# conversion to cm^-1: 1e-2 / units.CODATA.c / units.AU2SECOND.
sgn = torch.sign(force_const_au)
freqs_au = torch.sqrt(torch.abs(force_const_au)) * sgn
Expand Down
15 changes: 11 additions & 4 deletions src/dxtb/_src/calculators/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,14 @@ class **inherits from all types**, i.e., it provides the energy and properties
if we need to differentiate for multiple properties at once (e.g., Hessian and
dipole moment for IR spectra). Hence, the default is ``use_functorch=False``.
"""
from .analytical import *
from .autograd import *
from .energy import *
from .numerical import *
from .analytical import AnalyticalCalculator
from .autograd import AutogradCalculator
from .energy import EnergyCalculator
from .numerical import NumericalCalculator

__all__ = [
"AnalyticalCalculator",
"AutogradCalculator",
"EnergyCalculator",
"NumericalCalculator",
]
27 changes: 17 additions & 10 deletions src/dxtb/_src/calculators/types/analytical.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def forces_analytical(
timer.stop("ograd")

overlap = self.cache["overlap"]
assert isinstance(overlap, ints.types.Overlap)
assert isinstance(overlap, ints.types.OverlapIntegral)
assert overlap.matrix is not None

# density matrix
Expand Down Expand Up @@ -200,7 +200,7 @@ def forces_analytical(
assert self.integrals.hcore is not None

cn = ncoord.cn_d3(self.numbers, positions)
dedcn, dedr = self.integrals.hcore.integral.get_gradient(
dedcn, dedr = self.integrals.hcore.get_gradient(
positions,
overlap.matrix,
overlap_grad,
Expand Down Expand Up @@ -410,7 +410,7 @@ def _forces_analytical(
self.ihelp,
self.opts.scf,
intmats,
self.integrals.hcore.integral.refocc,
self.integrals.hcore.refocc,
)

timer.stop("SCF")
Expand Down Expand Up @@ -462,7 +462,7 @@ def _forces_analytical(
)

cn = ncoord.cn_d3(self.numbers, positions)
dedcn, dedr = self.integrals.hcore.integral.get_gradient(
dedcn, dedr = self.integrals.hcore.get_gradient(
positions,
intmats.overlap,
overlap_grad,
Expand Down Expand Up @@ -524,6 +524,9 @@ def dipole_analytical(
Tensor
Electric dipole moment of shape `(..., 3)`.
"""
# require caching for analytical calculation at end of function
kwargs["store_dipole"] = True

# run single point and check if integral is populated
result = self.singlepoint(positions, chrg, spin, **kwargs)

Expand All @@ -534,19 +537,23 @@ def dipole_analytical(
f"be added automatically if the '{efield.LABEL_EFIELD}' "
"interaction is added to the Calculator."
)
if dipint.matrix is None:

# Use try except to raise more informative error message, because
# `dipint.matrix` already raises a RuntimeError if the matrix is None.
try:
_ = dipint.matrix
except RuntimeError as e:
raise RuntimeError(
"Dipole moment requires a dipole integral. They should "
f"be added automatically if the '{efield.LABEL_EFIELD}' "
"interaction is added to the Calculator."
)
"interaction is added to the Calculator. This is "
"probably a bug. Check the cache setup.\n\n"
f"Original error: {str(e)}"
) from e

# pylint: disable=import-outside-toplevel
from ..properties.moments.dip import dipole

# dip = properties.dipole(
# numbers, positions, result.density, self.integrals.dipole
# )
qat = self.ihelp.reduce_orbital_to_atom(result.charges.mono)
dip = dipole(qat, positions, result.density, dipint.matrix)
return dip
Expand Down
Loading

0 comments on commit e999f14

Please sign in to comment.