From 67e8564c399e9acd7421f905f22ce17e749198c4 Mon Sep 17 00:00:00 2001 From: Johannes Steinmetzer Date: Wed, 24 Jan 2024 14:12:34 +0100 Subject: [PATCH] mod: integrals over shells are now stored in 1d arrays FortranRenderer is not yet updated --- sympleints/NumbaRenderer.py | 30 ++++++++++++----- sympleints/PythonRenderer.py | 10 ++++++ sympleints/Renderer.py | 6 ++-- sympleints/defs/coulomb.py | 4 +++ sympleints/main.py | 43 +++--------------------- sympleints/templates/numba_driver.tpl | 7 ++-- sympleints/templates/numba_equi_func.tpl | 9 +++-- sympleints/templates/numba_func.tpl | 2 +- sympleints/templates/numba_if_driver.tpl | 28 +++++++++++++++ 9 files changed, 81 insertions(+), 58 deletions(-) create mode 100644 sympleints/templates/numba_if_driver.tpl diff --git a/sympleints/NumbaRenderer.py b/sympleints/NumbaRenderer.py index 7da4b74..258f368 100644 --- a/sympleints/NumbaRenderer.py +++ b/sympleints/NumbaRenderer.py @@ -6,15 +6,9 @@ ArgKind.EXPO: "f8", ArgKind.CONTR: "f8", ArgKind.CENTER: "f8[:]", - ArgKind.RESULT1: "f8[:]", - ArgKind.RESULT2: "f8[:, :]", - ArgKind.RESULT3: "f8[:, :, :]", - ArgKind.RESULT4: "f8[:, :, :, :]", + ArgKind.RESULT1: "f8[::1]", } -# TODO: investigate contiguity by using the ::1 -# See: https://numba.pydata.org/numba-doc/latest/reference/types.html - _container_type_map = { ArgKind.EXPO: "f8[:]", ArgKind.CONTR: "f8[:]", @@ -26,7 +20,7 @@ def func_type_from_functions(functions): """Numba function signature.""" args = [_type_map[arg_kind] for arg_kind in functions.full_arg_kinds] # Add result type - args += [_type_map[functions.result_kind]] + args += [_type_map[ArgKind.RESULT1]] args_str = ", ".join(args) signature = f"numba.types.void({args_str})" return signature @@ -38,7 +32,7 @@ def driver_func_type_from_functions(functions): # Prepend angular momenta that will be passed to the driver args = ["i8" for _ in range(functions.nbfs)] + args args_str = ", ".join(args) - result = _type_map[functions.result_kind] + result = _type_map[ArgKind.RESULT1] signature = f"{result}({args_str}, func_dict_type)" return signature @@ -52,6 +46,7 @@ class NumbaRenderer(PythonRenderer): "equi_func": "numba_equi_func.tpl", "func_dict": "py_func_dict.tpl", "driver": "numba_driver.tpl", + # "driver": "numba_if_driver.tpl", "module": "numba_module.tpl", } _suffix = "_numba" @@ -87,9 +82,26 @@ def render_driver_func(self, functions, rendered_funcs): ) return rendered + def render_if_driver_func(self, functions, rendered_funcs): + tpl = self.get_template(key="driver") + conds_funcs = list() + Ls = ("La", "Lb", "Lc", "Ld")[: functions.nbfs] + for rfunc in rendered_funcs: + Lconds = [f"{L} == {Lval}" for L, Lval in zip(Ls, rfunc.Ls)] + conds_funcs.append((Lconds, rfunc.name)) + + rendered = tpl.render( + name=functions.name, + Ls=Ls, + args=functions.full_args + ["result"], + conds_funcs=conds_funcs, + ) + return rendered + def render_module(self, functions, rendered_funcs, **tpl_kwargs): func_type = func_type_from_functions(functions) driver_func = self.render_driver_func(functions, rendered_funcs) + # driver_func = self.render_if_driver_func(functions, rendered_funcs) tpl_kwargs.update( { "func_type": func_type, diff --git a/sympleints/PythonRenderer.py b/sympleints/PythonRenderer.py index dcc2deb..b41743c 100644 --- a/sympleints/PythonRenderer.py +++ b/sympleints/PythonRenderer.py @@ -38,6 +38,7 @@ def render_function( shape_iter, args, name, + L_tots, doc_str="", ): # This allows using the 'boys' function without producing an error @@ -71,6 +72,8 @@ def render_function( n_return_vals=len(reduced), doc_str=doc_str, shape=shape, + functions=functions, + primitive=functions.primitive, ) return rendered @@ -89,6 +92,9 @@ def render_equi_function( tpl = self.get_template(key="equi_func") nbfs = functions.nbfs assert nbfs in (2, 3), "Implement other cases in template!" + if functions.ncomponents == 1: + from_axes = tuple([i - 1 for i in from_axes[1:]]) + to_axes = tuple([i - 1 for i in to_axes[1:]]) rendered = tpl.render( equi_name=equi_name, equi_args=equi_args, @@ -97,6 +103,8 @@ def render_equi_function( from_axes=from_axes, to_axes=to_axes, reshape=reshape, + functions=functions, + primitive=functions.primitive, ) return rendered @@ -114,6 +122,8 @@ def render_module(self, functions, rendered_funcs, **tpl_kwargs): "boys": functions.boys_func, "funcs": rendered_funcs, "func_dict": func_dict, + "name": functions.name, + "args": functions.full_args, } _tpl_kwargs.update(tpl_kwargs) rendered = tpl.render(**_tpl_kwargs) diff --git a/sympleints/Renderer.py b/sympleints/Renderer.py index a45cdca..b7ea016 100644 --- a/sympleints/Renderer.py +++ b/sympleints/Renderer.py @@ -53,6 +53,7 @@ def render_function( shape_iter, args, name, + L_tots, doc_str="", ): raise NotImplementedError @@ -62,9 +63,9 @@ def render_equi_function( self, functions, name, - act_name, + equi_name, equi_inds, - L_tots, + reshape, from_axes, to_axes, ): @@ -107,6 +108,7 @@ def render_functions(self, functions: Functions): shape_iter, args=args, name=name, + L_tots=L_tots, doc_str=doc_str, ) dur = time.time() - start diff --git a/sympleints/defs/coulomb.py b/sympleints/defs/coulomb.py index 009b4ca..906935e 100644 --- a/sympleints/defs/coulomb.py +++ b/sympleints/defs/coulomb.py @@ -341,14 +341,17 @@ def recur_vrr_aux_sph(cart_ind): return recur_vrr(2) +""" class ThreeCenterTwoElectron(ThreeCenterTwoElectronBase): aux_vrr = "cart" +""" class ThreeCenterTwoElectronSph(ThreeCenterTwoElectronBase): aux_vrr = "sph" +""" class ThreeCenterTwoElectronShell(Function): @classmethod def eval(cls, La_tot, Lb_tot, Lc_tot, a, b, c, A, B, C): @@ -359,6 +362,7 @@ def eval(cls, La_tot, Lb_tot, Lc_tot, a, b, c, A, B, C): ] # print(ThreeCenterTwoElectron.eval.cache_info()) return exprs, lmns +""" class ThreeCenterTwoElectronSphShell(Function): diff --git a/sympleints/main.py b/sympleints/main.py index 35d9147..67d3021 100644 --- a/sympleints/main.py +++ b/sympleints/main.py @@ -70,7 +70,6 @@ from sympleints.defs.coulomb import ( CoulombShell, TwoCenterTwoElectronShell, - # ThreeCenterTwoElectronShell, ThreeCenterTwoElectronSphShell, ) @@ -105,7 +104,6 @@ "kin", "coul", "2c2e", - # "3c2e", # not really practical "3c2e_sph", ) Normalization = Enum("Normalization", ["PGTO", "CGTO", "NONE"]) @@ -453,7 +451,7 @@ def run( ) renderers = [ PythonRenderer(), - # NumbaRenderer(), + NumbaRenderer(), # FortranRenderer(), ] results = dict() @@ -919,38 +917,6 @@ def doc_func(L_tots): # Three-center two-electron repulsion integrals # ################################################# - """ - # NOT YET UPDATED! - def _3center2electron(): - def _3center2el_doc_func(L_tots): - La_tot, Lb_tot, Lc_tot = L_tots - shell_a = L_MAP[La_tot] - shell_b = L_MAP[Lb_tot] - shell_c = L_MAP[Lc_tot] - return ( - f"{INT_KIND} ({shell_a}{shell_b}|{shell_c}) " - "three-center two-electron repulsion integral." - ) - - _3center2el_ints_Ls = integral_gen( - lambda La_tot, Lb_tot, Lc_tot: ThreeCenterTwoElectronShell( - La_tot, Lb_tot, Lc_tot, ax, bx, cx, center_A, center_B, center_C - ), - (l_max, l_max, l_aux_max), - (ax, bx, cx), - "_3center2el3d", - (A_map, B_map, C_map), - ) - write_render( - _3center2el_ints_Ls, - (ax, da, A, bx, db, B, cx, dc, C), - "_3center2el3d", - _3center2el_doc_func, - c=False, - py_kwargs={"add_imports": boys_import}, - ) - """ - def _3center2electron_sph(): def doc_func(L_tots): La_tot, Lb_tot, Lc_tot = L_tots @@ -1045,19 +1011,21 @@ def doc_func(L_tots): """ funcs = { + # Functions "prefactor": prefactor, "gto": gto, # Cartesian Gaussian-type-orbital for density evaluation + # Integrals "ovlp": overlap, # Overlap integrals "dpm": dipole, # Linear moment (dipole) integrals "dqpm": diag_quadrupole, # Diagonal part of the quadrupole tensor "qpm": quadrupole, # Quadratic moment (quadrupole) integrals "multi_sph": multipole_sph, # Integrals for distributed multipole analysis "kin": kinetic, # Kinetic energy integrals + # "4covlp": fourcenter_overlap, # Four center overlap integral + # Utilizing Boys-function below "coul": coulomb, # 1-electron Coulomb integrals "2c2e": _2center2electron, # 2-center-2-electron density fitting integrals - # "3c2e": _3center2electron, # 3-center-2-electron integrals for DF "3c2e_sph": _3center2electron_sph, # Sph. 3-center-2-electron DF integrals - # "4covlp": fourcenter_overlap, # Four center overlap integral } assert set(funcs.keys()) == set( KEYS @@ -1127,7 +1095,6 @@ def parse_args(args): ) parser.add_argument( "--boys-func", - # default="pysisyphus.wavefunction.ints.boys", default="sympleints.testing.boys", help="Which Boys-function to use.", ) diff --git a/sympleints/templates/numba_driver.tpl b/sympleints/templates/numba_driver.tpl index a800846..9a69e8f 100644 --- a/sympleints/templates/numba_driver.tpl +++ b/sympleints/templates/numba_driver.tpl @@ -1,8 +1,5 @@ int_tuple_type = numba.types.UniTuple(i8, {{ nbfs }}) func_dict_type = numba.types.DictType(int_tuple_type, func_type) -driver_func_type = numba.types.FunctionType( - {{ driver_func_type }} -) # Sadly, this function can't be cached. @numba.jit(func_dict_type(), nopython=True, cache=True) @@ -26,6 +23,10 @@ def get_func_dict(): {# +driver_func_type = numba.types.FunctionType( + {{ driver_func_type }} +) + @numba.jit( driver_func_type.signature, nopython=True, diff --git a/sympleints/templates/numba_equi_func.tpl b/sympleints/templates/numba_equi_func.tpl index ec1159b..4fc2300 100644 --- a/sympleints/templates/numba_equi_func.tpl +++ b/sympleints/templates/numba_equi_func.tpl @@ -8,9 +8,8 @@ def {{ equi_name }}({{ args }}, result): """See docstring of {{ name }}.""" - # np.moveaxis() is not yet supported by numba as of 0.58.1 - # result = numpy.moveaxis({{ name }}({{ equi_args }}), {{ from_axes }}, {{ to_axes }}) - # Call equivalent function and write to result - {{ name }}({{ equi_args }}, result) - result = numpy.transpose(result, axes={{ from_axes }}) + tmp = numpy.zeros_like(result) + {{ name }}({{ equi_args }}, tmp) + result[:] {% if not primitive %}+{% endif %}= numpy.transpose(tmp.reshape({{ reshape|join(", ") }}), axes={{ from_axes }}).flatten() + diff --git a/sympleints/templates/numba_func.tpl b/sympleints/templates/numba_func.tpl index 396f33e..3da7a69 100644 --- a/sympleints/templates/numba_func.tpl +++ b/sympleints/templates/numba_func.tpl @@ -16,5 +16,5 @@ def {{ name }}({{ args }}, result): # {{ n_return_vals }} item(s) {% for inds, res_line in results_iter %} - result[{{ inds|join(", ")}}] += {{ res_line }} + result[{{ loop.index0 }}] {% if not primitive %}+{% endif %}= {{ res_line }} {% endfor %} diff --git a/sympleints/templates/numba_if_driver.tpl b/sympleints/templates/numba_if_driver.tpl new file mode 100644 index 0000000..4324fce --- /dev/null +++ b/sympleints/templates/numba_if_driver.tpl @@ -0,0 +1,28 @@ +{# +@numba.jit( + nopython=True, + nogil=True, + fastmath=True, + cache=True, +) +def {{ name }}({{ Ls|join(", ") }}, {{ args|join(", ") }}): + {% for Lconds, func_name in conds_funcs %} + {%+ if loop.index0 > 0 %}el{% endif %}if {{ Lconds|join(" and ") }}: + {{ func_name }}({{ args|join(", ") }}) + {% endfor %} + else: + {{ args[-1] }}[:] = numpy.nan +#} + +@numba.jit( + nopython=True, + nogil=True, + fastmath=True, + cache=True, +) +def {{ name }}({{ Ls|join(", ") }}, {{ args|join(", ") }}): + {% for Lconds, func_name in conds_funcs %} + {%+ if loop.index0 > 0 %}el{% endif %}if {{ Lconds|join(" and ") }}: + func = {{ func_name }} + {% endfor %} + func({{ args|join(", ") }})