Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix batched charges #180

Merged
merged 1 commit into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions examples/issues/179/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import torch

import dxtb
from dxtb.typing import DD

dd: DD = {"device": torch.device("cpu"), "dtype": torch.double}

num1 = torch.tensor(
[8, 1, 1, 1, 8, 1, 1, 8, 1, 1, 8, 1, 1, 8, 1, 1, 8, 1, 1],
device=dd["device"],
)
num2 = torch.tensor(
[8, 1, 1, 1, 8, 1, 1, 8, 1, 1, 8, 1, 1, 8, 1, 1, 8, 1, 1],
device=dd["device"],
)

pos1 = torch.tensor(
[
[-7.68281384, 1.3350934, 0.74846383],
[-5.7428588, 1.31513411, 0.36896714],
[-8.23756184, -0.19765779, 1.67193897],
[-8.13313558, 2.93710683, 1.6453921],
[-2.95915993, 1.40005084, 0.24966306],
[-2.1362031, 1.4795743, -1.38758999],
[-2.40235213, 2.84218589, 1.24419946],
[-8.2640369, 5.79677268, 2.54733192],
[-8.68767571, 7.18194193, 1.3350556],
[-9.27787497, 6.09327071, 4.03498102],
[-9.34575393, -2.54164384, 3.28062124],
[-8.59029812, -3.46388688, 4.6567765],
[-10.71898011, -3.58163572, 2.65211723],
[-9.5591796, 9.66793334, -0.53212042],
[-8.70438089, 11.29169941, -0.5990394],
[-11.12723654, 9.8483266, -1.43755624],
[-2.69970054, 5.55135395, 2.96084179],
[-1.59244386, 6.50972855, 4.06699298],
[-4.38439138, 6.18065165, 3.1939773],
],
**dd
)

pos2 = torch.tensor(
[
[-7.67436676, 1.33433562, 0.74512468],
[-5.75285545, 1.30220838, 0.37189432],
[-8.23155251, -0.20308887, 1.67397231],
[-8.15184386, 2.94589406, 1.6474141],
[-2.96380866, 1.39739578, 0.24572676],
[-2.14413995, 1.48993378, -1.37321106],
[-2.39808135, 2.86614761, 1.25247646],
[-8.26855335, 5.79452391, 2.54948621],
[-8.69277797, 7.18061912, 1.33247046],
[-9.28819287, 6.08797948, 4.03809906],
[-9.3377226, -2.54245643, 3.27861813],
[-8.59693106, -3.48501402, 4.65503795],
[-10.72627446, -3.59514726, 2.66139579],
[-9.55955755, 9.6716561, -0.53106973],
[-8.7077635, 11.28708848, -0.59527696],
[-11.12540351, 9.87000175, -1.44181568],
[-2.70194931, 5.55490663, 2.9641866],
[-1.60305656, 6.49854138, 4.07984311],
[-4.39083534, 6.17898869, 3.18702311],
],
**dd
)

charge1 = torch.tensor(1, **dd)
charge2 = torch.tensor(1, **dd)


##############################################################################


numbers = torch.stack([num1, num2])
positions = torch.stack([pos1, pos2])
charge = torch.tensor([charge1, charge2])

# no conformers -> batched mode 1
opts = {"verbosity": 0, "batch_mode": 1}

calc = dxtb.Calculator(numbers, dxtb.GFN1_XTB, opts=opts, **dd)
result = calc.energy(positions, chrg=charge)


##############################################################################


calc = dxtb.Calculator(num1, dxtb.GFN1_XTB, opts={"verbosity": 0}, **dd)
result1 = calc.energy(pos1, chrg=charge1)


##############################################################################


calc = dxtb.Calculator(num2, dxtb.GFN1_XTB, opts={"verbosity": 0}, **dd)
result2 = calc.energy(pos2, chrg=charge2)


##############################################################################


assert torch.allclose(result[0], result1)
assert torch.allclose(result[1], result2)

print("Issue 179 is fixed!")
3 changes: 3 additions & 0 deletions src/dxtb/_src/calculators/types/energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from dxtb._src.utils.tensors import tensor_id

from ..result import Result
from ..utils import shape_checks_chrg
from . import decorators as cdec
from .base import BaseCalculator

Expand Down Expand Up @@ -116,6 +117,8 @@ def singlepoint(
if spin is not None:
spin = any_to_tensor(spin, **self.dd)

assert shape_checks_chrg(chrg, self.numbers.ndim, name="Charge")

result = Result(positions, **self.dd)

###########################
Expand Down
73 changes: 73 additions & 0 deletions src/dxtb/_src/calculators/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# This file is part of dxtb.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Calculators: Utility
====================

Collection of utility functions for the calculator.
"""

from __future__ import annotations

from dxtb._src.typing import Literal, NoReturn, Tensor

__all__ = ["shape_checks_chrg"]


def shape_checks_chrg(
t: Tensor, ndims: int, name: str = "Charge"
) -> Literal[True] | NoReturn:
"""
Check the shape of a tensor.

Parameters
----------
t : Tensor
The tensor to check.
ndims : int
The number of dimensions indicating single or batched calculations.

Raises
------
ValueError
If the tensor has not the expected number of dimensions.
"""

if t.ndim > 1:
raise ValueError(
f"{name.title()} tensor has more than 1 dimension. "
"Please use a 1D tensor for batched calculations "
"(e.g., `torch.tensor([1.0, 0.0])`), instead of "
"a 2D tensor (e.g., NOT `torch.tensor([[1.0], [0.0]])`)."
)

if t.ndim == 1 and t.numel() == 1:
raise ValueError(
f"{name.title()} tensor has only one element. Please use a "
"scalar for single structures (e.g., `torch.tensor(1.0)`) and "
"a 1D tensor for batched calculations (e.g., "
)

if ndims != t.ndim + 1:
raise ValueError(
f"{name.title()} tensor has invalid shape: {t.shape}.\n "
"Please use a scalar for single structures (e.g., "
"`torch.tensor(1.0)`) and a 1D tensor for batched "
"calculations (e.g., `torch.tensor([1.0, 0.0])`)."
)

return True
2 changes: 1 addition & 1 deletion src/dxtb/_src/components/classicals/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def get_energy(
if len(self.components) <= 0:
return {"none": positions.new_zeros(positions.shape[:-1])}

energies = {}
energies: dict[str, Tensor] = {}
for classical in self.components:
timer.start(classical.label, parent_uid="Classicals")
energies[classical.label] = classical.get_energy(
Expand Down
5 changes: 4 additions & 1 deletion src/dxtb/_src/scf/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from tad_mctc import storch

from dxtb import IndexHelper
from dxtb._src.calculators.utils import shape_checks_chrg
from dxtb._src.components.interactions import (
InteractionList,
InteractionListCache,
Expand Down Expand Up @@ -176,8 +177,10 @@ def get_refocc(
torch.tensor(0, device=refs.device, dtype=refs.dtype),
)

assert shape_checks_chrg(chrg, n0.ndim, name="Charge")

# Obtain the reference occupation and total number of electrons
nel = torch.sum(n0, -1) - torch.sum(chrg, -1)
nel = torch.sum(n0, -1) - chrg

# get alpha and beta electrons and occupation
nab = filling.get_alpha_beta_occupation(nel, spin)
Expand Down
48 changes: 46 additions & 2 deletions test/test_calculator/test_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,16 @@
from dxtb import GFN1_XTB as par
from dxtb import Calculator, labels
from dxtb._src.timing import timer
from dxtb.typing import DD

from ..conftest import DEVICE

def test_fail() -> None:
numbers = torch.tensor([6, 1, 1, 1, 1], dtype=torch.double)

@pytest.mark.parametrize("dtype", [torch.float, torch.double])
def test_fail(dtype: torch.dtype) -> None:
dd: DD = {"dtype": dtype, "device": DEVICE}

numbers = torch.tensor([6, 1, 1, 1, 1], **dd)

with pytest.raises(DtypeError):
Calculator(numbers, par, opts={"verbosity": 0})
Expand All @@ -39,6 +45,44 @@ def test_fail() -> None:
timer.reset()


@pytest.mark.parametrize("dtype", [torch.float, torch.double])
def test_fail_charge_single(dtype: torch.dtype) -> None:
dd: DD = {"dtype": dtype, "device": DEVICE}

numbers = torch.tensor([3, 1], device=DEVICE)
positions = torch.zeros(2, 3, **dd)

calc = Calculator(numbers, par, opts={"verbosity": 0})

# charge must be a scalar for single structure
with pytest.raises(ValueError) as excinfo:
charge = torch.tensor([0.0], **dd)
calc.singlepoint(positions, chrg=charge)

assert "Charge tensor has only one element" in str(excinfo)


@pytest.mark.parametrize("dtype", [torch.float, torch.double])
def test_fail_charge_batch(dtype: torch.dtype) -> None:
dd: DD = {"dtype": dtype, "device": DEVICE}

numbers = torch.tensor([[3, 1], [3, 1]], device=DEVICE)
positions = torch.zeros(2, 2, 3, **dd)

calc = Calculator(numbers, par, opts={"verbosity": 0})
with pytest.raises(ValueError) as excinfo:
charge = torch.tensor([[0.0], [0.0]], **dd)
calc.singlepoint(positions, chrg=charge)

assert "Charge tensor has more than 1 dimension" in str(excinfo)

with pytest.raises(ValueError) as excinfo:
charge = torch.tensor(0.0, **dd)
calc.singlepoint(positions, chrg=charge)

assert "Charge tensor has invalid shape" in str(excinfo)


def run_asserts(c: Calculator, dtype: torch.dtype) -> None:
assert c.dtype == dtype
assert c.classicals.dtype == dtype
Expand Down
Loading