More tests
marvinfriede committed Aug 24, 2024
1 parent 71f1d01 commit cae90ee
Showing 20 changed files with 615 additions and 42 deletions.
16 changes: 16 additions & 0 deletions src/dxtb/_src/calculators/types/energy.py
@@ -325,15 +325,31 @@ def singlepoint(

         if kwargs.get("store_hcore", copts.hcore):
             self.cache["hcore"] = self.integrals.hcore
+        else:
+            if self.integrals.hcore is not None:
+                if self.integrals.hcore.requires_grad is False:
+                    self.integrals.hcore.clear()

         if kwargs.get("store_overlap", copts.overlap):
             self.cache["overlap"] = self.integrals.overlap
+        else:
+            if self.integrals.overlap is not None:
+                if self.integrals.overlap.requires_grad is False:
+                    self.integrals.overlap.clear()

         if kwargs.get("store_dipole", copts.dipole):
             self.cache["dipint"] = self.integrals.dipole
+        else:
+            if self.integrals.dipole is not None:
+                if self.integrals.dipole.requires_grad is False:
+                    self.integrals.dipole.clear()

         if kwargs.get("store_quadrupole", copts.quadrupole):
             self.cache["quadint"] = self.integrals.quadrupole
+        else:
+            if self.integrals.quadrupole is not None:
+                if self.integrals.quadrupole.requires_grad is False:
+                    self.integrals.quadrupole.clear()

         self._ncalcs += 1
         return result
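Note: the four new branches above repeat one guard. A minimal sketch of that pattern, with `maybe_clear` and `integral` as hypothetical stand-ins for the hcore/overlap/dipole/quadrupole objects from the diff:

# Drop an integral only if it is not cached and carries no gradient
# information; clearing a tensor that is still part of the autograd
# graph would break a later backward pass.
def maybe_clear(integral) -> None:
    if integral is not None and integral.requires_grad is False:
        integral.clear()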
8 changes: 4 additions & 4 deletions src/dxtb/_src/components/interactions/coulomb/secondorder.py
@@ -32,11 +32,11 @@
 # Define atomic numbers, positions, and charges
 numbers = torch.tensor([14, 1, 1, 1, 1])
 positions = torch.tensor([
-    [0.00000000000000, -0.00000000000000, 0.00000000000000],
-    [1.61768389755830, 1.61768389755830, -1.61768389755830],
+    [+0.00000000000000, -0.00000000000000, +0.00000000000000],
+    [+1.61768389755830, +1.61768389755830, -1.61768389755830],
     [-1.61768389755830, -1.61768389755830, -1.61768389755830],
-    [1.61768389755830, -1.61768389755830, 1.61768389755830],
-    [-1.61768389755830, 1.61768389755830, 1.61768389755830],
+    [+1.61768389755830, -1.61768389755830, +1.61768389755830],
+    [-1.61768389755830, +1.61768389755830, +1.61768389755830],
 ])
 q = torch.tensor([
8 changes: 4 additions & 4 deletions src/dxtb/_src/components/interactions/coulomb/thirdorder.py
@@ -33,11 +33,11 @@
 # Define atomic numbers and their positions
 numbers = torch.tensor([14, 1, 1, 1, 1])
 positions = torch.tensor([
-    [0.00000000000000, -0.00000000000000, 0.00000000000000],
-    [1.61768389755830, 1.61768389755830, -1.61768389755830],
+    [+0.00000000000000, -0.00000000000000, +0.00000000000000],
+    [+1.61768389755830, +1.61768389755830, -1.61768389755830],
     [-1.61768389755830, -1.61768389755830, -1.61768389755830],
-    [1.61768389755830, -1.61768389755830, 1.61768389755830],
-    [-1.61768389755830, 1.61768389755830, 1.61768389755830],
+    [+1.61768389755830, -1.61768389755830, +1.61768389755830],
+    [-1.61768389755830, +1.61768389755830, +1.61768389755830],
 ])
 # Atomic charges
8 changes: 4 additions & 4 deletions src/dxtb/_src/components/interactions/solvation/alpb.py
@@ -29,11 +29,11 @@
 numbers = torch.tensor([14, 1, 1, 1, 1])
 positions = torch.tensor([
-    [0.00000000000000, -0.00000000000000, 0.00000000000000],
-    [1.61768389755830, 1.61768389755830, -1.61768389755830],
+    [+0.00000000000000, -0.00000000000000, +0.00000000000000],
+    [+1.61768389755830, +1.61768389755830, -1.61768389755830],
     [-1.61768389755830, -1.61768389755830, -1.61768389755830],
-    [1.61768389755830, -1.61768389755830, 1.61768389755830],
-    [-1.61768389755830, 1.61768389755830, 1.61768389755830],
+    [+1.61768389755830, -1.61768389755830, +1.61768389755830],
+    [-1.61768389755830, +1.61768389755830, +1.61768389755830],
 ])
 charges = torch.tensor([
     -8.41282505804719e-2,
8 changes: 4 additions & 4 deletions src/dxtb/_src/components/interactions/solvation/born.py
@@ -31,11 +31,11 @@
 # Define atomic numbers and positions of the atoms
 numbers = torch.tensor([14, 1, 1, 1, 1])
 positions = torch.tensor([
-    [0.00000000000000, -0.00000000000000, 0.00000000000000],
-    [1.61768389755830, 1.61768389755830, -1.61768389755830],
+    [+0.00000000000000, -0.00000000000000, +0.00000000000000],
+    [+1.61768389755830, +1.61768389755830, -1.61768389755830],
     [-1.61768389755830, -1.61768389755830, -1.61768389755830],
-    [1.61768389755830, -1.61768389755830, 1.61768389755830],
-    [-1.61768389755830, 1.61768389755830, 1.61768389755830],
+    [+1.61768389755830, -1.61768389755830, +1.61768389755830],
+    [-1.61768389755830, +1.61768389755830, +1.61768389755830],
 ])
 # Calculate the Born radii for the given atomic configuration
33 changes: 24 additions & 9 deletions src/dxtb/_src/integral/base.py
@@ -187,15 +187,12 @@ def setup(self, positions: Tensor, **kwargs) -> None:
         """

     def __str__(self) -> str:
-        dict_repr = []
-        for key, value in self.__dict__.items():
-            if isinstance(value, Tensor):
-                value_repr = f"{value.shape}"
-            else:
-                value_repr = repr(value)
-            dict_repr.append(f" {key}: {value_repr}")
-        dict_str = "{\n" + ",\n".join(dict_repr) + "\n}"
-        return f"{self.__class__.__name__}({dict_str})"
+        return (
+            f"{self.__class__.__name__}("
+            f"Family: {self.family}, "
+            f"Number of Atoms: {self.numbers.shape[-1]}, "
+            f"Setup?: {self.is_setup()})"
+        )

     def __repr__(self) -> str:
         return str(self)
@@ -328,6 +325,22 @@ def clear(self) -> None:
         self._norm = None
         self._gradient = None

+    @property
+    def requires_grad(self) -> bool:
+        """
+        Check whether any field of the integral class requires a gradient.
+
+        Returns
+        -------
+        bool
+            Flag for gradient requirement.
+        """
+        for field in (self._matrix, self._gradient, self._norm):
+            if field is not None and field.requires_grad:
+                return True
+
+        return False
+
     def normalize(self, norm: Tensor | None = None) -> None:
         """
         Normalize the integral (changes ``self.matrix``).
@@ -401,6 +414,8 @@ def __str__(self) -> str:
             d["_matrix"] = self._matrix.shape
         if self._gradient is not None:
             d["_gradient"] = self._gradient.shape
+        if self._norm is not None:
+            d["_norm"] = self._norm.shape

         return f"{self.__class__.__name__}({d})"

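Note: a short, self-contained illustration of why the new `requires_grad` property gates the cache clearing above (plain PyTorch, no dxtb API assumed):

import torch

# A tensor built from differentiable inputs stays in the autograd graph.
pos = torch.tensor([[0.0, 0.0, 1.0]], requires_grad=True)
s = pos @ pos.mT                # stand-in for an integral built from positions
assert s.requires_grad          # -> such a tensor must not be freed yet
(g,) = torch.autograd.grad(s.sum(), pos)
print(g)                        # tensor([[0., 0., 2.]])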
6 changes: 4 additions & 2 deletions src/dxtb/_src/integral/driver/factory.py
@@ -57,7 +57,9 @@ def new_driver(
         return new_driver_pytorch(numbers, par, device=device, dtype=dtype)

     if name == labels.INTDRIVER_AUTOGRAD:
-        return new_driver_pytorch2(numbers, par, device=device, dtype=dtype)
+        return new_driver_pytorch_no_analytical(
+            numbers, par, device=device, dtype=dtype
+        )

     if name == labels.INTDRIVER_LEGACY:
         return new_driver_legacy(numbers, par, device=device, dtype=dtype)
@@ -97,7 +99,7 @@ def new_driver_pytorch(
     return _IntDriver(numbers, par, ihelp, device=device, dtype=dtype)


-def new_driver_pytorch2(
+def new_driver_pytorch_no_analytical(
     numbers: Tensor,
     par: Param,
     device: torch.device | None = None,
2 changes: 1 addition & 1 deletion src/dxtb/_src/integral/types/quadrupole.py
@@ -68,7 +68,7 @@ def traceless(self) -> Tensor:
         zx zy zz    6 7 8    6 7 8
         """

-        if self.matrix.shape[-3] != 9:
+        if self.matrix.ndim not in (3, 4) or self.matrix.shape[-3] != 9:
             raise RuntimeError(
                 "Quadrupole integral must be a tensor of shape "
                 f"'(9, nao, nao)' but is {self.matrix.shape}."
3 changes: 2 additions & 1 deletion src/dxtb/_src/integral/utils.py
@@ -30,4 +30,5 @@

 def snorm(ovlp: Tensor) -> Tensor:
     d = ovlp.diagonal(dim1=-1, dim2=-2)
-    return torch.where(d == 0.0, 0.0, torch.pow(d, -0.5))
+    zero = torch.tensor(0.0, dtype=d.dtype, device=d.device)
+    return torch.where(d == 0.0, zero, torch.pow(d, -0.5))
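Note: the explicit tensor pins the dtype and device of the fill value — presumably the motivation here, since a plain Python scalar is not accepted by torch.where on older PyTorch versions and can interact badly with mixed precision. A quick check of the new behavior:

import torch

d = torch.tensor([4.0, 0.0, 1.0], dtype=torch.float64)
zero = torch.tensor(0.0, dtype=d.dtype, device=d.device)
print(torch.where(d == 0.0, zero, torch.pow(d, -0.5)))
# tensor([0.5000, 0.0000, 1.0000], dtype=torch.float64)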
8 changes: 3 additions & 5 deletions src/dxtb/_src/integral/wrappers.py
@@ -72,10 +72,8 @@
 from dxtb._src.xtb.gfn1 import GFN1Hamiltonian
 from dxtb._src.xtb.gfn2 import GFN2Hamiltonian

-from .driver.factory import new_driver
 from .driver.manager import DriverManager
 from .factory import new_dipint, new_overlap, new_quadint
-from .types import DipoleIntegral, OverlapIntegral, QuadrupoleIntegral

 __all__ = ["hcore", "overlap", "dipint", "quadint"]
@@ -100,20 +98,20 @@ def hcore(numbers: Tensor, positions: Tensor, par: Param, **kwargs: Any) -> Tensor:
     Raises
     ------
-    TypeError
+    ValueError
         If the Hamiltonian parametrization does not contain meta data or if the
         meta data does not contain a name to select the correct Hamiltonian.
     ValueError
         If the Hamiltonian name is unknown.
     """
     if par.meta is None:
-        raise TypeError(
+        raise ValueError(
             "Meta data of Hamiltonian parametrization must contain a name. "
             "Otherwise, the correct Hamiltonian cannot be selected internally."
         )

     if par.meta.name is None:
-        raise TypeError(
+        raise ValueError(
             "The name field of the meta data of the Hamiltonian "
             "parametrization must contain a name. Otherwise, the correct "
             "Hamiltonian cannot be selected internally."
7 changes: 7 additions & 0 deletions src/dxtb/_src/xtb/base.py
@@ -131,6 +131,13 @@ def clear(self) -> None:
         """
         self._matrix = None

+    @property
+    def requires_grad(self) -> bool:
+        if self._matrix is None:
+            return False
+
+        return self._matrix.requires_grad
+
     def get_occupation(self) -> Tensor:
         """
         Obtain the reference occupation numbers for each orbital.
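Note: this mirrors the integral container's property above — the Hamiltonian backs the hcore integral, so the cache-clearing code in energy.py can apply one guard to both. A sketch of that shared guard (`obj` is a hypothetical stand-in for either object):

def can_clear(obj) -> bool:
    # safe to free only when present and outside the autograd graph
    return obj is not None and not obj.requires_grad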
88 changes: 88 additions & 0 deletions test/test_calculator/test_cache/test_integrals.py
@@ -0,0 +1,88 @@
# This file is part of dxtb.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Test caching integrals.
"""

from __future__ import annotations

import pytest
import torch

from dxtb._src.typing import DD, Tensor
from dxtb.calculators import GFN1Calculator

from ...conftest import DEVICE

opts = {"cache_enabled": True, "verbosity": 0}


@pytest.mark.parametrize("dtype", [torch.float, torch.double])
def test_overlap_deleted(dtype: torch.dtype) -> None:
    dd: DD = {"device": DEVICE, "dtype": dtype}

    numbers = torch.tensor([3, 1], device=DEVICE)
    positions = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]], **dd)

    calc = GFN1Calculator(numbers, opts={"verbosity": 0}, **dd)
    assert calc._ncalcs == 0

    # overlap should not be cached
    assert calc.opts.cache.store.overlap == False

    # one successful calculation
    energy = calc.get_energy(positions)
    assert calc._ncalcs == 1
    assert isinstance(energy, Tensor)

    # cache should be empty
    assert calc.cache.overlap is None

    # ... but also the tensors in the calculator should be deleted
    assert calc.integrals.overlap is not None
    assert calc.integrals.overlap._matrix is None
    assert calc.integrals.overlap._norm is None
    assert calc.integrals.overlap._gradient is None


@pytest.mark.parametrize("dtype", [torch.float, torch.double])
def test_overlap_retained_for_grad(dtype: torch.dtype) -> None:
    dd: DD = {"device": DEVICE, "dtype": dtype}

    numbers = torch.tensor([3, 1], device=DEVICE)
    positions = torch.tensor(
        [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]], **dd, requires_grad=True
    )

    calc = GFN1Calculator(numbers, opts={"verbosity": 0}, **dd)
    assert calc._ncalcs == 0

    # overlap should not be cached
    assert calc.opts.cache.store.overlap == False

    # one successful calculation
    energy = calc.get_energy(positions)
    assert calc._ncalcs == 1
    assert isinstance(energy, Tensor)

    # cache should still be empty ...
    assert calc.cache.overlap is None

    # ... but the tensors in the calculator should still be there
    assert calc.integrals.overlap is not None
    assert calc.integrals.overlap._matrix is not None
    assert calc.integrals.overlap._norm is not None
3 changes: 2 additions & 1 deletion test/test_indexhelper/test_extra.py
@@ -168,7 +168,8 @@ def test_spread_unique_batch() -> None:
     x = torch.randn((nbatch, nat_u, 3), device=DEVICE)

     # pollutes CUDA memory
-    assert False
+    if DEVICE is not None:
+        assert False

     out = ihelp.spread_uspecies_to_atom(x, dim=-2, extra=True)
     assert out.shape == torch.Size((nbatch, nat, 3))
16 changes: 16 additions & 0 deletions test/test_integrals/test_driver/__init__.py
@@ -0,0 +1,16 @@
# This file is part of dxtb.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
