From 87eb3815d26d85dd0ba7ddd47477107db4cd3742 Mon Sep 17 00:00:00 2001 From: Johannes Steinmetzer Date: Thu, 5 Sep 2024 11:44:23 +0200 Subject: [PATCH] add: numba_multipole backend --- pysisyphus/wavefunction/backend_numba.py | 2 +- .../wavefunction/backend_numba_multipole.py | 298 ++++++++++++++++++ pysisyphus/wavefunction/shells.py | 45 ++- 3 files changed, 335 insertions(+), 10 deletions(-) create mode 100644 pysisyphus/wavefunction/backend_numba_multipole.py diff --git a/pysisyphus/wavefunction/backend_numba.py b/pysisyphus/wavefunction/backend_numba.py index 8d3fb24120..5bb1b33230 100644 --- a/pysisyphus/wavefunction/backend_numba.py +++ b/pysisyphus/wavefunction/backend_numba.py @@ -237,7 +237,7 @@ def get_1el_ints_cart( ): org_components = components components = max(1, components) - symmetric = shells_b == None + symmetric = shells_b is None if symmetric: shells_b = shells_a diff --git a/pysisyphus/wavefunction/backend_numba_multipole.py b/pysisyphus/wavefunction/backend_numba_multipole.py new file mode 100644 index 0000000000..35d79819ad --- /dev/null +++ b/pysisyphus/wavefunction/backend_numba_multipole.py @@ -0,0 +1,298 @@ +import numba +import numpy as np + + +_FUNC_DATA = { + # Le, components + "int1e_ovlp": (0, 0), + "int1e_r": (1, 3), + "int1e_rr": (2, 9), +} + + +def get_func_data(key): + # Le, components + return _FUNC_DATA[key] + + +@numba.jit(nopython=True, cache=True) +def canonical_order(L: int) -> np.ndarray: + inds = np.zeros(((L + 2) * (L + 1) // 2, 3), dtype=np.int64) + i = 0 + for j in range(L + 1): + l = L - j + for n in range(j + 1): + m = j - n + inds[i] = (l, m, n) + i += 1 + return inds + + +@numba.jit(nopython=True, cache=True) +def factorial2(n: int) -> int: + """Double factorial for positive integer arguments and 0 and -1.""" + if n == -1: + return 1 + offset = n % 2 + result = 1 + for i in range(2 + offset, n + 1, 2): + result *= i + return result + + +@numba.jit(nopython=True, cache=True) +def lmn_factors(l: int, m: int, n: int) -> float: + return 1 / np.sqrt( + factorial2(2 * l - 1) * factorial2(2 * m - 1) * factorial2(2 * n - 1) + ) + + +@numba.jit(nopython=True, cache=True) +def multipole1d_(i, j, e, px, pa, pb, pr, base): + """1d-multipole integral.""" + # Base case + if (i < 0) or (j < 0) or (e < 0) or ((i == 0) and (j == 0) and (e == 0)): + # return np.sqrt(np.pi / px) + return base + # Decrement bra + elif i > 0: + return pa * multipole1d(i - 1, j, e, px, pa, pb, pr, base) + 1 / (2 * px) * ( + (i - 1) * multipole1d(i - 2, j, e, px, pa, pb, pr, base) + + j * multipole1d(i - 1, j - 1, e, px, pa, pb, pr, base) + + e * multipole1d(i - 1, j, e - 1, px, pa, pb, pr, base) + ) + # Decrement ket + elif j > 0: + return pb * multipole1d(i, j - 1, e, px, pa, pb, pr, base) + 1 / (2 * px) * ( + i * multipole1d(i - 1, j - 1, e, px, pa, pb, pr, base) + + (j - 1) * multipole1d(i, j - 2, e, px, pa, pb, pr, base) + + e * multipole1d(i, j - 1, e - 1, px, pa, pb, pr, base) + ) + # Decrement multipole order + # e > 0 + else: + return pr * multipole1d(i, j, e - 1, px, pa, pb, pr, base) + 1 / (2 * px) * ( + i * multipole1d(i - 1, j, e - 1, px, pa, pb, pr, base) + + j * multipole1d(i, j - 1, e - 1, px, pa, pb, pr, base) + + (e - 1) * multipole1d(i, j, e - 2, px, pa, pb, pr, base) + ) + + +@numba.jit(nopython=True, nogil=True, cache=True) +def multipole1d(i, j, e, px, pa, pb, pr, base): + """1d-multipole integral.""" + + def vrr(i, j, e): + return multipole1d(i, j, e, px, pa, pb, pr, base) + + # Base case + if (i < 0) or (j < 0) or (e < 0) or ((i == 0) and (j == 0) and (e == 0)): + # return np.sqrt(np.pi / px) + return base + # Decrement bra + elif i > 0: + return pa * vrr(i - 1, j, e) + 1 / (2 * px) * ( + (i - 1) * vrr(i - 2, j, e) + + j * vrr(i - 1, j - 1, e) + + e * vrr(i - 1, j, e - 1) + ) + # Decrement ket + elif j > 0: + return pb * vrr(i, j - 1, e) + 1 / (2 * px) * ( + i * vrr(i - 1, j - 1, e) + + (j - 1) * vrr(i, j - 2, e) + + e * vrr(i, j - 1, e - 1) + ) + # Decrement multipole order + # e > 0 + else: + return pr * vrr(i, j, e - 1) + 1 / (2 * px) * ( + i * vrr(i - 1, j, e - 1) + + j * vrr(i, j - 1, e - 1) + + (e - 1) * vrr(i, j, e - 2) + ) + + +@numba.jit(nopython=True, cache=True, nogil=True, fastmath=True) +def multipole3d(La, Lb, Le, axs, das, A, bxs, dbs, B, R, exp_thresh=-36.0): + """3d-multipole integral.""" + # Angular momenta of the different shells + lmns_a = canonical_order(La) + lmns_b = canonical_order(Lb) + lmns_e = canonical_order(Le) + na = len(lmns_a) + nb = len(lmns_b) + ne = len(lmns_e) + # Final integrals + integrals = np.zeros((na, nb, ne)) + + # Construct angular momenta dependent normalization factors. Actually they + # should be precalculated somehow. + lmn_norms = np.zeros((na, nb)) + for i in range(na): + la, ma, na_ = lmns_a[i] + flmna = lmn_factors(la, ma, na_) + for j in range(nb): + lb, mb, nb_ = lmns_b[j] + lmn_norms[i, j] = flmna * lmn_factors(lb, mb, nb_) + + # Number of primitives in bra and ket + nprimsa = len(axs) + nprimsb = len(bxs) + AB = A - B + AB2 = AB**2 + AB2sum = AB2.sum() + + # Determine most diffuse exponent pair in both shells and calulcate + # the associated exp-argument. When this is already very small then + # we skip the whole shell pair. + ax_min = axs.min() + bx_min = bxs.min() + min_exp_arg = -(ax_min * bx_min) / (ax_min + bx_min) * AB2sum + if min_exp_arg <= exp_thresh: + return integrals + + # Loop over pairs of primitives + for a in range(nprimsa): + ax = axs[a] + da = das[a] + for b in range(nprimsb): + bx = bxs[b] + dadb = da * dbs[b] + + px = ax + bx + mux = ax * bx / px + exp_arg = -mux * AB2sum + # Skip primitive pair when exp-argument is very small + if exp_arg <= exp_thresh: + continue + K = np.exp(exp_arg) + P = (ax * A + bx * B) / px + PA = P - A + PB = P - B + PR = P - R + base = np.sqrt(np.pi / px) + # Loop over triples of angular momenta + for i in range(na): + lmna = lmns_a[i] + for j in range(nb): + lmnb = lmns_b[j] + for k in range(ne): + lmne = lmns_e[k] + # Build up x-, y- and z-terms + tmp = 1.0 + for m in range(3): + tmp *= multipole1d( + lmna[m], lmnb[m], lmne[m], px, PA[m], PB[m], PR[m], base + ) + integrals[i, j, k] += dadb * K * tmp + + # Apply lmn-dependent basis function normalization + for i in range(na): + for j in range(nb): + integrals[i, j] *= lmn_norms[i, j] + return integrals + + +@numba.jit(parallel=True, nopython=True, cache=True) +def get_multipole_ints_cart_numba( + Le, + R, + shells_a, + shells_b, + symmetric, +): + components = 2 * Le + 1 + + tot_size_a = 0 + for shell in shells_a: + tot_size_a += shell.size + + tot_size_b = 0 + for shell in shells_b: + tot_size_b += shell.size + + # Allocate final integral array + integrals = np.zeros((tot_size_a, tot_size_b, components)) + shells_b = shells_a + nshells_a = len(shells_a) + nshells_b = len(shells_b) + + # Start parallel loop over contracted gaussians in shells_a + for i in numba.prange(nshells_a): + shell_a = shells_a[i] + La, A, _, das, axs, indexa, sizea = shell_a.as_tuple() + slicea = slice(indexa, indexa + sizea) + + # Start loop over contracted gaussians in shells_b + if not symmetric: + i = 0 + for j in range(i, nshells_b): + shell_b = shells_b[j] + Lb, B, _, dbs, bxs, indexb, sizeb = shell_b.as_tuple() + sliceb = slice(indexb, indexb + sizeb) + + result = multipole3d(La, Lb, Le, axs, das, A, bxs, dbs, B, R) + integrals[slicea, sliceb, :] = result + + if symmetric and (i != j): + for k in range(indexa, indexa + sizea): + for l in range(indexb, indexb + sizeb): + integrals[l, k, :] = integrals[k, l, :] + # End loop over contracted gaussians in shells_b + # End loop over contracted gaussians in shells_a + return integrals + + +def get_multipole_ints_cart( + Le: int, + R: np.ndarray, + shellstructs_a, + shellstructs_b=None, +): + symmetric = shellstructs_b is None + if symmetric: + shellstructs_b = shellstructs_a + + integrals = get_multipole_ints_cart_numba( + Le, + R, + shellstructs_a, + shellstructs_b, + symmetric, + ) + if integrals.shape[2] == 1: + integrals = np.squeeze(integrals, axis=2) + return integrals + + +def get_1el_ints_cart(shells, func_dict, shells_b, **kwargs): + R = kwargs.get("R", np.zeros(3)) + Le = func_dict + return get_multipole_ints_cart(Le, R, shells, shells_b) + + +def get_multipole_ints_sph(Le: int, R: np.ndarray, shells_a, shells_b=None): + """This function expects pysisyphus.Shells not Shellstructs""" + shellstructs_a = shells_a.numba_shellstructs + if shells_b is not None: + shellstructs_b = shells_b.numba_shellstructs + else: + shellstructs_b = shellstructs_a + + integrals_cart = get_multipole_ints_cart(Le, R, shellstructs_a, shellstructs_b) + + c2s_coeffs_a = shells_a.reorder_c2s_coeffs + if shells_b is not None: + c2s_coeffs_b = shells_b.reorder_c2s_coeffs + else: + c2s_coeffs_b = c2s_coeffs_a + + int_matrix_sph = np.einsum( + "ij,jk...,kl->il...", + c2s_coeffs_a, + integrals_cart, + c2s_coeffs_b.T, + optimize="greedy", + ) + return int_matrix_sph diff --git a/pysisyphus/wavefunction/shells.py b/pysisyphus/wavefunction/shells.py index 2fa5f4d498..e2f00d8bd1 100644 --- a/pysisyphus/wavefunction/shells.py +++ b/pysisyphus/wavefunction/shells.py @@ -8,7 +8,7 @@ from math import sqrt, log, pi from pathlib import Path import textwrap -from typing import List, Literal, Union +from typing import List, Literal, Optional, Union import warnings @@ -49,15 +49,15 @@ class Shell: def __init__( self, L: int, - center: NDArray[float], - coeffs: NDArray[float], - exps: NDArray[float], + center: np.ndarray, + coeffs: np.ndarray, + exps: np.ndarray, center_ind: int, - atomic_num=None, + atomic_num: Optional[int] = None, # TODO: sph_index and cart_index?! - shell_index=None, - index=None, - sph_index=None, + shell_index: Optional[int] = None, + index: Optional[int] = None, + sph_index: Optional[int] = None, # min_coeff: float = 1e-8, ): self.L = get_l(L) @@ -195,7 +195,7 @@ def __repr__(self): Ordering = Literal["native", "pysis"] -IntegralBackend = Enum("IntegralBackend", ["PYTHON", "NUMBA"]) +IntegralBackend = Enum("IntegralBackend", ["PYTHON", "NUMBA", "NUMBA_MULTIPOLE"]) _backend_modules = { IntegralBackend.PYTHON: backend_python, } @@ -225,6 +225,7 @@ def __init__( # ValueError is raised when backend is a string, not an Enum except ValueError: self.backend = IntegralBackend[backend.upper()] + # Only ever import (and compile) numba backend when actually requested, # as the compilation/setup takes quite some time. if self.backend == IntegralBackend.NUMBA: @@ -238,6 +239,20 @@ def __init__( "not installed!.\n Falling back to python backend." ) self.backend = IntegralBackend.PYTHON + elif self.backend == IntegralBackend.NUMBA_MULTIPOLE: + try: + from pysisyphus.wavefunction import backend_numba_multipole + + _backend_modules[IntegralBackend.NUMBA_MULTIPOLE] = ( + backend_numba_multipole + ) + except ModuleNotFoundError: + print( + "numba integral backend was requested but numba package is " + "not installed!.\n Falling back to python backend." + ) + self.backend = IntegralBackend.PYTHON + # Pick the actual backend module self.backend_module = _backend_modules[self.backend] # End integral backend setup @@ -284,6 +299,7 @@ def __init__( self.get_1el_ints_sph = self.memory.cache(self.get_1el_ints_sph) self._numba_shells = None + self._numba_shellstructs = None @property def nprims(self): @@ -315,6 +331,12 @@ def as_numba_shells(self): return backend_numba.to_numba_shells(self) + @property + def numba_shellstructs(self): + if self._numba_shellstructs is None: + self._numba_shellstructs = self.as_numba_shellstructs() + return self._numba_shellstructs + def as_numba_shellstructs(self): from pysisyphus.wavefunction import backend_numba @@ -685,6 +707,11 @@ def get_1el_ints_cart( shells = self.numba_shells if other is not None: other = other.numba_shells + elif self.backend == IntegralBackend.NUMBA_MULTIPOLE: + shells = self.numba_shellstructs + if other is not None: + other = other.numba_shellstructs + return self.backend_module.get_1el_ints_cart( shells, func_dict,