diff --git a/sympleints/FortranRenderer.py b/sympleints/FortranRenderer.py index 01d9c8a..6336e14 100644 --- a/sympleints/FortranRenderer.py +++ b/sympleints/FortranRenderer.py @@ -7,7 +7,7 @@ from sympy.printing.fortran import FCodePrinter from sympleints.Renderer import Renderer -from sympleints.helpers import shell_shape_iter +from sympleints.helpers import shell_shape_iter, get_reorder_inds def format_with_fprettify(fortran: str): @@ -41,22 +41,6 @@ def format_with_fprettify(fortran: str): class FCodePrinterMod(FCodePrinter): boys_re = re.compile(r"boys\(([d\d\.]+),(.+)") - """ - def _print_Function(self, expr): - func = super()._print_Function(expr) - # Sympy prints everything as float (1.0d0, 2.0d0 etc.), even integers, but the - # first argument 'n' to the Boys function must be integer. - if func.startswith("boys"): - # Only try to fix calls to the Boys function when a double is present - if mobj := self.boys_re.match(func): - as_float = float(mobj.group(1).lower().replace("d", "e")) - as_int = int(as_float) - assert abs(as_float - as_int) <= 1e-14 - remainder = mobj.group(2) - func = f"boys({as_int},{remainder}" - return func - """ - def _print_AppliedUndef(self, expr): """For printing the Boys function. @@ -99,11 +83,15 @@ def make_fortran_comment(comment_str): class FortranRenderer(Renderer): ext = ".f90" real_kind = "kind=real64" - res_name = "res" + res_name = "result" language = "Fortran" _primitive = True _drop_dim = False + resort_func_dict = { + 2: ("resort_ba_ab", (2, 1)), + 3: ("resort_bac_abc", (2, 1)), + } def shell_shape_iter(self, *args, **kwargs): # Start indexing at 1, instead of 0. @@ -121,16 +109,27 @@ def get_argument_declaration(self, functions, contracted=False): coeffs=functions.coeffs, zip=zip, centers=functions.centers, - ref_center=functions.ref_center - if (contracted or functions.with_ref_center) - else None, + ref_center=( + functions.ref_center + if (contracted or functions.with_ref_center) + else None + ), res_name=self.res_name, res_dim=res_dim, ) return rendered def render_function( - self, functions, repls, reduced, shape, shape_iter, args, name, doc_str="" + self, + functions, + repls, + reduced, + shape, + shape_iter, + args, + name, + L_tots, + doc_str="", ): if (not functions.primitive) or (not self._primitive): warnings.warn("FortranRenderer always produces subroutines for primitives!") @@ -161,60 +160,51 @@ def render_function( ) return rendered - """ - def render_function( - self, functions, repls, reduced, shape, shape_iter, args, name, doc_str="" + def render_equi_function( + self, + functions, + name, + equi_name, + equi_inds, + shape, ): - # This allows using the 'boys' function without producing an error - print_settings = { - "allow_unknown_functions": True, - # Without disabling contract some expressions will raise ValueError. - "contract": False, - "standard": 2008, - "source_format": "free", - } - print_func = FCodePrinterMod(print_settings).doprint - assignments = [Assignment(lhs, rhs) for lhs, rhs in repls] - repl_lines = [print_func(as_) for as_ in assignments] - results = [print_func(red) for red in reduced] - res_len = len(reduced) - results_iter = zip(shape_iter, results) + """ + tpl = self.get_template(fn="fortran_equi_func.tpl") + nbfs = functions.nbfs + assert nbfs in (2,), "Implement other cases in template!" + # shape refers to the shape of the original function that is used to generate + # the equivalent function here. + reorder_inds = ( + get_reorder_inds( + shape, + functions.ncomponents, + herm_axes=functions.hermitian, + ) + + 1 # Add one as Fortran array indices start at 1. + ) + """ - tmps = [lhs for lhs, rhs in repls] - from sympy import IndexedBase - tmp = IndexedBase('tmp', shape=len(tmps)) - map_ = {} - for i, rhs_ in enumerate(tmps, 1): - map_[rhs_] = tmp[i - 1] - rhs_subs = list() - for _, rhs in repls: - rhs_subs.append(rhs.subs(map_)) - red_subs = list() - for red in reduced: - red_subs.append(red.subs(map_)) - results_sub = [print_func(red) for red in red_subs] - results_iter = zip(shape_iter, results_sub) - repl_rhss = [print_func(rhs) for rhs in rhs_subs] - doc_str = make_fortran_comment(doc_str) arg_declaration = self.get_argument_declaration(functions) - - tpl = self.get_template(fn="fortran_function_arr.tpl") + resort_func, new_order = self.resort_func_dict[functions.nbfs] + sizes = shape + sizes = [shape[i] for i in new_order] + equi_args = ", ".join(functions.full_args_for_bf_inds(equi_inds)) + args = ", ".join(functions.full_args) + ncomponents = functions.ncomponents + + tpl = self.get_template(fn="fortran_equi_func.tpl") rendered = tpl.render( - name=name, - args=functions.full_args, - doc_str=doc_str, - arg_declaration=arg_declaration, - res_name=self.res_name, - res_len=res_len, - assignments=assignments, - repl_lines=repl_lines, - results_iter=results_iter, - reduced=reduced, - kind=self.real_kind, - repl_rhs=repl_rhss, - ) + res_name=self.res_name, + arg_declaration=arg_declaration, + equi_name=equi_name, + equi_args=functions.full_args_for_bf_inds(equi_inds), + name=name, + args=functions.full_args, + sizes=sizes, + resort_func=resort_func, + ncomponents=ncomponents, + ) return rendered - """ def render_f_init(self, name, rendered_funcs, func_array_name="func_array"): tpl = self.get_template(fn="fortran_init.tpl") @@ -261,6 +251,8 @@ def loop_iter(counters, exps, coeffs): loops=loops, pntr_args=pntr_args, args=args, + kind=self.real_kind, + res_name=self.res_name, ) return rendered @@ -277,12 +269,13 @@ def render_module(self, functions, rendered_funcs, **tpl_kwargs): l_max = functions.l_max func_arr_dims = [f"0:{l_max}" for _ in range(functions.nbfs)] contr_driver = self.render_contracted_driver(functions) + resort_func, _ = self.resort_func_dict[functions.nbfs] tpl = self.get_template(fn="fortran_module.tpl") rendered = tpl.render( header=header, mod_name=mod_name, - boys=functions.boys, + boys=functions.boys_func, args=functions.full_args + [self.res_name], interface_name=interface_name, contr_driver=contr_driver, @@ -290,6 +283,9 @@ def render_module(self, functions, rendered_funcs, **tpl_kwargs): comment=comment, func_arr_dims=func_arr_dims, init=init, + res_name=self.res_name, + resort_func=resort_func, + kind=self.real_kind, funcs=rendered_funcs, ) rendered = format_with_fprettify(rendered) diff --git a/sympleints/templates/fortran_arg_declaration.tpl b/sympleints/templates/fortran_arg_declaration.tpl index 4bf196b..5e04f38 100644 --- a/sympleints/templates/fortran_arg_declaration.tpl +++ b/sympleints/templates/fortran_arg_declaration.tpl @@ -17,4 +17,4 @@ real({{ kind }}), intent(in), dimension(3) :: {{ centers|join(", ") }} real({{ kind }}), intent(in), dimension(3) :: {{ ref_center }} {% endif %} ! Return value -real({{ kind }}), intent(in out) :: {{ res_name }}({{ res_dim }}) +real({{ kind }}), intent(in out) :: {{ res_name }}(:) diff --git a/sympleints/templates/fortran_contracted_driver.tpl b/sympleints/templates/fortran_contracted_driver.tpl index 4d41620..83a1f75 100644 --- a/sympleints/templates/fortran_contracted_driver.tpl +++ b/sympleints/templates/fortran_contracted_driver.tpl @@ -1,19 +1,19 @@ subroutine {{ name }}({{ L_args|join(", ") }}, {{ args|join(", ") }}) integer, intent(in) :: {{ Ls|join(", ") }} {{ arg_declaration }} - real(kind=real64), allocatable, dimension({{ res_dim }}) :: res_tmp + real({{ kind }}), allocatable :: res_tmp(:) ! Initializing with => null () adds an implicit save, which will mess ! everything up when running with OpenMP. procedure({{ name }}_proc), pointer :: fncpntr integer :: {{ loop_counter|join(", ") }} - allocate(res_tmp, mold=res) + allocate(res_tmp, mold={{ res_name }}) fncpntr => func_array({{ L_args|join(", ") }})%f - res = 0 + {{ res_name }} = 0 {{ loops|join("\n") }} call fncpntr({{ pntr_args }}, res_tmp) - res = res + res_tmp + {{ res_name }} = {{ res_name }} + res_tmp {% for _ in loops %} end do {% endfor %} diff --git a/sympleints/templates/fortran_equi_func.tpl b/sympleints/templates/fortran_equi_func.tpl new file mode 100644 index 0000000..70dcc3a --- /dev/null +++ b/sympleints/templates/fortran_equi_func.tpl @@ -0,0 +1,9 @@ +subroutine {{ equi_name }} ({{ args|join(", ") }}, {{ res_name }}) + + ! See docstring of {{ name }}. + +{{ arg_declaration }} + +call {{ name }}({{ equi_args|join(", ") }}, {{ res_name }}) +call {{ resort_func }}({{ res_name }}, {{ sizes|join(", ")}}, {{ ncomponents }}) +end subroutine {{ equi_name }} diff --git a/sympleints/templates/fortran_function.tpl b/sympleints/templates/fortran_function.tpl index 3a6ffbf..95ffa5b 100644 --- a/sympleints/templates/fortran_function.tpl +++ b/sympleints/templates/fortran_function.tpl @@ -16,6 +16,6 @@ real({{ kind }}) :: {{ as_.lhs }} {% endfor %} {% for inds, res_line in results_iter %} -{{ res_name }}({{ inds|join(", ") }}) = {{ res_line }} +{{ res_name }}({{ loop.index }}) = {{ res_line }} {% endfor %} end subroutine {{ name }} diff --git a/sympleints/templates/fortran_module.tpl b/sympleints/templates/fortran_module.tpl index 5739974..66a05af 100644 --- a/sympleints/templates/fortran_module.tpl +++ b/sympleints/templates/fortran_module.tpl @@ -11,7 +11,7 @@ use mod_boys, only: boys implicit none type fp - procedure({{ interface_name }}) ,pointer ,nopass :: f =>null() + procedure({{ interface_name }}), pointer, nopass :: f =>null() end type fp interface @@ -26,6 +26,31 @@ type(fp) :: func_array({{ func_arr_dims|join(", ") }}) contains {{ init }} +{% if resort_func == "resort_ba_ab" %} + subroutine resort_ba_ab({{ res_name }}, sizea, sizeb, ncomponents) + real(real64), intent(in out) :: {{ res_name }}(:) + integer, intent(in) :: sizea, sizeb, ncomponents + integer :: sizeab, component_size + integer :: nc, a, b, i, j + real(real64) :: tmp(size({{ res_name }})) + + sizeab = sizea * sizeb + i = 1 + do nc = 1, ncomponents + do b = 1, sizeb + do a = 1, sizea + j = component_size + ((a-1) * sizeb) + b + tmp(j) = {{ res_name }}(i) + i = i + 1 + end do + end do + component_size = component_size + sizeab + end do + + {{ res_name }} = tmp + end subroutine resort_ba_ab +{% endif %} + {{ contr_driver }} {% for func in funcs %}