Skip to content

Commit

Permalink
mod: integrals over shells are now stored in 1d arrays
Browse files Browse the repository at this point in the history
FortranRenderer is not yet updated
  • Loading branch information
eljost committed Jan 24, 2024
1 parent 45d11f2 commit 67e8564
Show file tree
Hide file tree
Showing 9 changed files with 81 additions and 58 deletions.
30 changes: 21 additions & 9 deletions sympleints/NumbaRenderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,9 @@
ArgKind.EXPO: "f8",
ArgKind.CONTR: "f8",
ArgKind.CENTER: "f8[:]",
ArgKind.RESULT1: "f8[:]",
ArgKind.RESULT2: "f8[:, :]",
ArgKind.RESULT3: "f8[:, :, :]",
ArgKind.RESULT4: "f8[:, :, :, :]",
ArgKind.RESULT1: "f8[::1]",
}

# TODO: investigate contiguity by using the ::1
# See: https://numba.pydata.org/numba-doc/latest/reference/types.html

_container_type_map = {
ArgKind.EXPO: "f8[:]",
ArgKind.CONTR: "f8[:]",
Expand All @@ -26,7 +20,7 @@ def func_type_from_functions(functions):
"""Numba function signature."""
args = [_type_map[arg_kind] for arg_kind in functions.full_arg_kinds]
# Add result type
args += [_type_map[functions.result_kind]]
args += [_type_map[ArgKind.RESULT1]]
args_str = ", ".join(args)
signature = f"numba.types.void({args_str})"
return signature
Expand All @@ -38,7 +32,7 @@ def driver_func_type_from_functions(functions):
# Prepend angular momenta that will be passed to the driver
args = ["i8" for _ in range(functions.nbfs)] + args
args_str = ", ".join(args)
result = _type_map[functions.result_kind]
result = _type_map[ArgKind.RESULT1]
signature = f"{result}({args_str}, func_dict_type)"
return signature

Expand All @@ -52,6 +46,7 @@ class NumbaRenderer(PythonRenderer):
"equi_func": "numba_equi_func.tpl",
"func_dict": "py_func_dict.tpl",
"driver": "numba_driver.tpl",
# "driver": "numba_if_driver.tpl",
"module": "numba_module.tpl",
}
_suffix = "_numba"
Expand Down Expand Up @@ -87,9 +82,26 @@ def render_driver_func(self, functions, rendered_funcs):
)
return rendered

def render_if_driver_func(self, functions, rendered_funcs):
tpl = self.get_template(key="driver")
conds_funcs = list()
Ls = ("La", "Lb", "Lc", "Ld")[: functions.nbfs]
for rfunc in rendered_funcs:
Lconds = [f"{L} == {Lval}" for L, Lval in zip(Ls, rfunc.Ls)]
conds_funcs.append((Lconds, rfunc.name))

rendered = tpl.render(
name=functions.name,
Ls=Ls,
args=functions.full_args + ["result"],
conds_funcs=conds_funcs,
)
return rendered

def render_module(self, functions, rendered_funcs, **tpl_kwargs):
func_type = func_type_from_functions(functions)
driver_func = self.render_driver_func(functions, rendered_funcs)
# driver_func = self.render_if_driver_func(functions, rendered_funcs)
tpl_kwargs.update(
{
"func_type": func_type,
Expand Down
10 changes: 10 additions & 0 deletions sympleints/PythonRenderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def render_function(
shape_iter,
args,
name,
L_tots,
doc_str="",
):
# This allows using the 'boys' function without producing an error
Expand Down Expand Up @@ -71,6 +72,8 @@ def render_function(
n_return_vals=len(reduced),
doc_str=doc_str,
shape=shape,
functions=functions,
primitive=functions.primitive,
)
return rendered

Expand All @@ -89,6 +92,9 @@ def render_equi_function(
tpl = self.get_template(key="equi_func")
nbfs = functions.nbfs
assert nbfs in (2, 3), "Implement other cases in template!"
if functions.ncomponents == 1:
from_axes = tuple([i - 1 for i in from_axes[1:]])
to_axes = tuple([i - 1 for i in to_axes[1:]])
rendered = tpl.render(
equi_name=equi_name,
equi_args=equi_args,
Expand All @@ -97,6 +103,8 @@ def render_equi_function(
from_axes=from_axes,
to_axes=to_axes,
reshape=reshape,
functions=functions,
primitive=functions.primitive,
)
return rendered

Expand All @@ -114,6 +122,8 @@ def render_module(self, functions, rendered_funcs, **tpl_kwargs):
"boys": functions.boys_func,
"funcs": rendered_funcs,
"func_dict": func_dict,
"name": functions.name,
"args": functions.full_args,
}
_tpl_kwargs.update(tpl_kwargs)
rendered = tpl.render(**_tpl_kwargs)
Expand Down
6 changes: 4 additions & 2 deletions sympleints/Renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def render_function(
shape_iter,
args,
name,
L_tots,
doc_str="",
):
raise NotImplementedError
Expand All @@ -62,9 +63,9 @@ def render_equi_function(
self,
functions,
name,
act_name,
equi_name,
equi_inds,
L_tots,
reshape,
from_axes,
to_axes,
):
Expand Down Expand Up @@ -107,6 +108,7 @@ def render_functions(self, functions: Functions):
shape_iter,
args=args,
name=name,
L_tots=L_tots,
doc_str=doc_str,
)
dur = time.time() - start
Expand Down
4 changes: 4 additions & 0 deletions sympleints/defs/coulomb.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,14 +341,17 @@ def recur_vrr_aux_sph(cart_ind):
return recur_vrr(2)


"""
class ThreeCenterTwoElectron(ThreeCenterTwoElectronBase):
aux_vrr = "cart"
"""


class ThreeCenterTwoElectronSph(ThreeCenterTwoElectronBase):
aux_vrr = "sph"


"""
class ThreeCenterTwoElectronShell(Function):
@classmethod
def eval(cls, La_tot, Lb_tot, Lc_tot, a, b, c, A, B, C):
Expand All @@ -359,6 +362,7 @@ def eval(cls, La_tot, Lb_tot, Lc_tot, a, b, c, A, B, C):
]
# print(ThreeCenterTwoElectron.eval.cache_info())
return exprs, lmns
"""


class ThreeCenterTwoElectronSphShell(Function):
Expand Down
43 changes: 5 additions & 38 deletions sympleints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@
from sympleints.defs.coulomb import (
CoulombShell,
TwoCenterTwoElectronShell,
# ThreeCenterTwoElectronShell,
ThreeCenterTwoElectronSphShell,
)

Expand Down Expand Up @@ -105,7 +104,6 @@
"kin",
"coul",
"2c2e",
# "3c2e", # not really practical
"3c2e_sph",
)
Normalization = Enum("Normalization", ["PGTO", "CGTO", "NONE"])
Expand Down Expand Up @@ -453,7 +451,7 @@ def run(
)
renderers = [
PythonRenderer(),
# NumbaRenderer(),
NumbaRenderer(),
# FortranRenderer(),
]
results = dict()
Expand Down Expand Up @@ -919,38 +917,6 @@ def doc_func(L_tots):
# Three-center two-electron repulsion integrals #
#################################################

"""
# NOT YET UPDATED!
def _3center2electron():
def _3center2el_doc_func(L_tots):
La_tot, Lb_tot, Lc_tot = L_tots
shell_a = L_MAP[La_tot]
shell_b = L_MAP[Lb_tot]
shell_c = L_MAP[Lc_tot]
return (
f"{INT_KIND} ({shell_a}{shell_b}|{shell_c}) "
"three-center two-electron repulsion integral."
)
_3center2el_ints_Ls = integral_gen(
lambda La_tot, Lb_tot, Lc_tot: ThreeCenterTwoElectronShell(
La_tot, Lb_tot, Lc_tot, ax, bx, cx, center_A, center_B, center_C
),
(l_max, l_max, l_aux_max),
(ax, bx, cx),
"_3center2el3d",
(A_map, B_map, C_map),
)
write_render(
_3center2el_ints_Ls,
(ax, da, A, bx, db, B, cx, dc, C),
"_3center2el3d",
_3center2el_doc_func,
c=False,
py_kwargs={"add_imports": boys_import},
)
"""

def _3center2electron_sph():
def doc_func(L_tots):
La_tot, Lb_tot, Lc_tot = L_tots
Expand Down Expand Up @@ -1045,19 +1011,21 @@ def doc_func(L_tots):
"""

funcs = {
# Functions
"prefactor": prefactor,
"gto": gto, # Cartesian Gaussian-type-orbital for density evaluation
# Integrals
"ovlp": overlap, # Overlap integrals
"dpm": dipole, # Linear moment (dipole) integrals
"dqpm": diag_quadrupole, # Diagonal part of the quadrupole tensor
"qpm": quadrupole, # Quadratic moment (quadrupole) integrals
"multi_sph": multipole_sph, # Integrals for distributed multipole analysis
"kin": kinetic, # Kinetic energy integrals
# "4covlp": fourcenter_overlap, # Four center overlap integral
# Utilizing Boys-function below
"coul": coulomb, # 1-electron Coulomb integrals
"2c2e": _2center2electron, # 2-center-2-electron density fitting integrals
# "3c2e": _3center2electron, # 3-center-2-electron integrals for DF
"3c2e_sph": _3center2electron_sph, # Sph. 3-center-2-electron DF integrals
# "4covlp": fourcenter_overlap, # Four center overlap integral
}
assert set(funcs.keys()) == set(
KEYS
Expand Down Expand Up @@ -1127,7 +1095,6 @@ def parse_args(args):
)
parser.add_argument(
"--boys-func",
# default="pysisyphus.wavefunction.ints.boys",
default="sympleints.testing.boys",
help="Which Boys-function to use.",
)
Expand Down
7 changes: 4 additions & 3 deletions sympleints/templates/numba_driver.tpl
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
int_tuple_type = numba.types.UniTuple(i8, {{ nbfs }})
func_dict_type = numba.types.DictType(int_tuple_type, func_type)
driver_func_type = numba.types.FunctionType(
{{ driver_func_type }}
)

# Sadly, this function can't be cached.
@numba.jit(func_dict_type(), nopython=True, cache=True)
Expand All @@ -26,6 +23,10 @@ def get_func_dict():


{#
driver_func_type = numba.types.FunctionType(
{{ driver_func_type }}
)

@numba.jit(
driver_func_type.signature,
nopython=True,
Expand Down
9 changes: 4 additions & 5 deletions sympleints/templates/numba_equi_func.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
def {{ equi_name }}({{ args }}, result):
"""See docstring of {{ name }}."""

# np.moveaxis() is not yet supported by numba as of 0.58.1
# result = numpy.moveaxis({{ name }}({{ equi_args }}), {{ from_axes }}, {{ to_axes }})

# Call equivalent function and write to result
{{ name }}({{ equi_args }}, result)
result = numpy.transpose(result, axes={{ from_axes }})
tmp = numpy.zeros_like(result)
{{ name }}({{ equi_args }}, tmp)
result[:] {% if not primitive %}+{% endif %}= numpy.transpose(tmp.reshape({{ reshape|join(", ") }}), axes={{ from_axes }}).flatten()

2 changes: 1 addition & 1 deletion sympleints/templates/numba_func.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@ def {{ name }}({{ args }}, result):

# {{ n_return_vals }} item(s)
{% for inds, res_line in results_iter %}
result[{{ inds|join(", ")}}] += {{ res_line }}
result[{{ loop.index0 }}] {% if not primitive %}+{% endif %}= {{ res_line }}
{% endfor %}
28 changes: 28 additions & 0 deletions sympleints/templates/numba_if_driver.tpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
{#
@numba.jit(
nopython=True,
nogil=True,
fastmath=True,
cache=True,
)
def {{ name }}({{ Ls|join(", ") }}, {{ args|join(", ") }}):
{% for Lconds, func_name in conds_funcs %}
{%+ if loop.index0 > 0 %}el{% endif %}if {{ Lconds|join(" and ") }}:
{{ func_name }}({{ args|join(", ") }})
{% endfor %}
else:
{{ args[-1] }}[:] = numpy.nan
#}

@numba.jit(
nopython=True,
nogil=True,
fastmath=True,
cache=True,
)
def {{ name }}({{ Ls|join(", ") }}, {{ args|join(", ") }}):
{% for Lconds, func_name in conds_funcs %}
{%+ if loop.index0 > 0 %}el{% endif %}if {{ Lconds|join(" and ") }}:
func = {{ func_name }}
{% endfor %}
func({{ args|join(", ") }})

0 comments on commit 67e8564

Please sign in to comment.