Skip to content

Commit

Permalink
mod: added 2d integral resorting to Fortran
Browse files Browse the repository at this point in the history
  • Loading branch information
eljost committed Feb 27, 2024
1 parent aec9dc3 commit cf6cb79
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 79 deletions.
140 changes: 68 additions & 72 deletions sympleints/FortranRenderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from sympy.printing.fortran import FCodePrinter

from sympleints.Renderer import Renderer
from sympleints.helpers import shell_shape_iter
from sympleints.helpers import shell_shape_iter, get_reorder_inds

Check failure on line 10 in sympleints/FortranRenderer.py

View workflow job for this annotation

GitHub Actions / build (3.9)

Ruff (F401)

sympleints/FortranRenderer.py:10:50: F401 `sympleints.helpers.get_reorder_inds` imported but unused

Check failure on line 10 in sympleints/FortranRenderer.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Ruff (F401)

sympleints/FortranRenderer.py:10:50: F401 `sympleints.helpers.get_reorder_inds` imported but unused

Check failure on line 10 in sympleints/FortranRenderer.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (F401)

sympleints/FortranRenderer.py:10:50: F401 `sympleints.helpers.get_reorder_inds` imported but unused

Check failure on line 10 in sympleints/FortranRenderer.py

View workflow job for this annotation

GitHub Actions / build (3.12)

Ruff (F401)

sympleints/FortranRenderer.py:10:50: F401 `sympleints.helpers.get_reorder_inds` imported but unused


def format_with_fprettify(fortran: str):
Expand Down Expand Up @@ -41,22 +41,6 @@ def format_with_fprettify(fortran: str):
class FCodePrinterMod(FCodePrinter):
boys_re = re.compile(r"boys\(([d\d\.]+),(.+)")

"""
def _print_Function(self, expr):
func = super()._print_Function(expr)
# Sympy prints everything as float (1.0d0, 2.0d0 etc.), even integers, but the
# first argument 'n' to the Boys function must be integer.
if func.startswith("boys"):
# Only try to fix calls to the Boys function when a double is present
if mobj := self.boys_re.match(func):
as_float = float(mobj.group(1).lower().replace("d", "e"))
as_int = int(as_float)
assert abs(as_float - as_int) <= 1e-14
remainder = mobj.group(2)
func = f"boys({as_int},{remainder}"
return func
"""

def _print_AppliedUndef(self, expr):
"""For printing the Boys function.
Expand Down Expand Up @@ -99,11 +83,15 @@ def make_fortran_comment(comment_str):
class FortranRenderer(Renderer):
ext = ".f90"
real_kind = "kind=real64"
res_name = "res"
res_name = "result"
language = "Fortran"

_primitive = True
_drop_dim = False
resort_func_dict = {
2: ("resort_ba_ab", (2, 1)),
3: ("resort_bac_abc", (2, 1)),
}

def shell_shape_iter(self, *args, **kwargs):
# Start indexing at 1, instead of 0.
Expand All @@ -121,16 +109,27 @@ def get_argument_declaration(self, functions, contracted=False):
coeffs=functions.coeffs,
zip=zip,
centers=functions.centers,
ref_center=functions.ref_center
if (contracted or functions.with_ref_center)
else None,
ref_center=(
functions.ref_center
if (contracted or functions.with_ref_center)
else None
),
res_name=self.res_name,
res_dim=res_dim,
)
return rendered

def render_function(
self, functions, repls, reduced, shape, shape_iter, args, name, doc_str=""
self,
functions,
repls,
reduced,
shape,
shape_iter,
args,
name,
L_tots,
doc_str="",
):
if (not functions.primitive) or (not self._primitive):
warnings.warn("FortranRenderer always produces subroutines for primitives!")
Expand Down Expand Up @@ -161,60 +160,51 @@ def render_function(
)
return rendered

"""
def render_function(
self, functions, repls, reduced, shape, shape_iter, args, name, doc_str=""
def render_equi_function(
self,
functions,
name,
equi_name,
equi_inds,
shape,
):
# This allows using the 'boys' function without producing an error
print_settings = {
"allow_unknown_functions": True,
# Without disabling contract some expressions will raise ValueError.
"contract": False,
"standard": 2008,
"source_format": "free",
}
print_func = FCodePrinterMod(print_settings).doprint
assignments = [Assignment(lhs, rhs) for lhs, rhs in repls]
repl_lines = [print_func(as_) for as_ in assignments]
results = [print_func(red) for red in reduced]
res_len = len(reduced)
results_iter = zip(shape_iter, results)
"""
tpl = self.get_template(fn="fortran_equi_func.tpl")
nbfs = functions.nbfs
assert nbfs in (2,), "Implement other cases in template!"
# shape refers to the shape of the original function that is used to generate
# the equivalent function here.
reorder_inds = (
get_reorder_inds(
shape,
functions.ncomponents,
herm_axes=functions.hermitian,
)
+ 1 # Add one as Fortran array indices start at 1.
)
"""

tmps = [lhs for lhs, rhs in repls]
from sympy import IndexedBase
tmp = IndexedBase('tmp', shape=len(tmps))
map_ = {}
for i, rhs_ in enumerate(tmps, 1):
map_[rhs_] = tmp[i - 1]
rhs_subs = list()
for _, rhs in repls:
rhs_subs.append(rhs.subs(map_))
red_subs = list()
for red in reduced:
red_subs.append(red.subs(map_))
results_sub = [print_func(red) for red in red_subs]
results_iter = zip(shape_iter, results_sub)
repl_rhss = [print_func(rhs) for rhs in rhs_subs]
doc_str = make_fortran_comment(doc_str)
arg_declaration = self.get_argument_declaration(functions)
tpl = self.get_template(fn="fortran_function_arr.tpl")
resort_func, new_order = self.resort_func_dict[functions.nbfs]
sizes = shape
sizes = [shape[i] for i in new_order]
equi_args = ", ".join(functions.full_args_for_bf_inds(equi_inds))

Check failure on line 191 in sympleints/FortranRenderer.py

View workflow job for this annotation

GitHub Actions / build (3.9)

Ruff (F841)

sympleints/FortranRenderer.py:191:9: F841 Local variable `equi_args` is assigned to but never used

Check failure on line 191 in sympleints/FortranRenderer.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Ruff (F841)

sympleints/FortranRenderer.py:191:9: F841 Local variable `equi_args` is assigned to but never used

Check failure on line 191 in sympleints/FortranRenderer.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (F841)

sympleints/FortranRenderer.py:191:9: F841 Local variable `equi_args` is assigned to but never used

Check failure on line 191 in sympleints/FortranRenderer.py

View workflow job for this annotation

GitHub Actions / build (3.12)

Ruff (F841)

sympleints/FortranRenderer.py:191:9: F841 Local variable `equi_args` is assigned to but never used
args = ", ".join(functions.full_args)

Check failure on line 192 in sympleints/FortranRenderer.py

View workflow job for this annotation

GitHub Actions / build (3.9)

Ruff (F841)

sympleints/FortranRenderer.py:192:9: F841 Local variable `args` is assigned to but never used

Check failure on line 192 in sympleints/FortranRenderer.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Ruff (F841)

sympleints/FortranRenderer.py:192:9: F841 Local variable `args` is assigned to but never used

Check failure on line 192 in sympleints/FortranRenderer.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (F841)

sympleints/FortranRenderer.py:192:9: F841 Local variable `args` is assigned to but never used

Check failure on line 192 in sympleints/FortranRenderer.py

View workflow job for this annotation

GitHub Actions / build (3.12)

Ruff (F841)

sympleints/FortranRenderer.py:192:9: F841 Local variable `args` is assigned to but never used
ncomponents = functions.ncomponents

tpl = self.get_template(fn="fortran_equi_func.tpl")
rendered = tpl.render(
name=name,
args=functions.full_args,
doc_str=doc_str,
arg_declaration=arg_declaration,
res_name=self.res_name,
res_len=res_len,
assignments=assignments,
repl_lines=repl_lines,
results_iter=results_iter,
reduced=reduced,
kind=self.real_kind,
repl_rhs=repl_rhss,
)
res_name=self.res_name,
arg_declaration=arg_declaration,
equi_name=equi_name,
equi_args=functions.full_args_for_bf_inds(equi_inds),
name=name,
args=functions.full_args,
sizes=sizes,
resort_func=resort_func,
ncomponents=ncomponents,
)
return rendered
"""

def render_f_init(self, name, rendered_funcs, func_array_name="func_array"):
tpl = self.get_template(fn="fortran_init.tpl")
Expand Down Expand Up @@ -261,6 +251,8 @@ def loop_iter(counters, exps, coeffs):
loops=loops,
pntr_args=pntr_args,
args=args,
kind=self.real_kind,
res_name=self.res_name,
)
return rendered

Expand All @@ -277,19 +269,23 @@ def render_module(self, functions, rendered_funcs, **tpl_kwargs):
l_max = functions.l_max
func_arr_dims = [f"0:{l_max}" for _ in range(functions.nbfs)]
contr_driver = self.render_contracted_driver(functions)
resort_func, _ = self.resort_func_dict[functions.nbfs]

tpl = self.get_template(fn="fortran_module.tpl")
rendered = tpl.render(
header=header,
mod_name=mod_name,
boys=functions.boys,
boys=functions.boys_func,
args=functions.full_args + [self.res_name],
interface_name=interface_name,
contr_driver=contr_driver,
arg_declaration=arg_declaration,
comment=comment,
func_arr_dims=func_arr_dims,
init=init,
res_name=self.res_name,
resort_func=resort_func,
kind=self.real_kind,
funcs=rendered_funcs,
)
rendered = format_with_fprettify(rendered)
Expand Down
2 changes: 1 addition & 1 deletion sympleints/templates/fortran_arg_declaration.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ real({{ kind }}), intent(in), dimension(3) :: {{ centers|join(", ") }}
real({{ kind }}), intent(in), dimension(3) :: {{ ref_center }}
{% endif %}
! Return value
real({{ kind }}), intent(in out) :: {{ res_name }}({{ res_dim }})
real({{ kind }}), intent(in out) :: {{ res_name }}(:)
8 changes: 4 additions & 4 deletions sympleints/templates/fortran_contracted_driver.tpl
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
subroutine {{ name }}({{ L_args|join(", ") }}, {{ args|join(", ") }})
integer, intent(in) :: {{ Ls|join(", ") }}
{{ arg_declaration }}
real(kind=real64), allocatable, dimension({{ res_dim }}) :: res_tmp
real({{ kind }}), allocatable :: res_tmp(:)
! Initializing with => null () adds an implicit save, which will mess
! everything up when running with OpenMP.
procedure({{ name }}_proc), pointer :: fncpntr
integer :: {{ loop_counter|join(", ") }}

allocate(res_tmp, mold=res)
allocate(res_tmp, mold={{ res_name }})
fncpntr => func_array({{ L_args|join(", ") }})%f

res = 0
{{ res_name }} = 0
{{ loops|join("\n") }}
call fncpntr({{ pntr_args }}, res_tmp)
res = res + res_tmp
{{ res_name }} = {{ res_name }} + res_tmp
{% for _ in loops %}
end do
{% endfor %}
Expand Down
9 changes: 9 additions & 0 deletions sympleints/templates/fortran_equi_func.tpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
subroutine {{ equi_name }} ({{ args|join(", ") }}, {{ res_name }})

! See docstring of {{ name }}.

{{ arg_declaration }}

call {{ name }}({{ equi_args|join(", ") }}, {{ res_name }})
call {{ resort_func }}({{ res_name }}, {{ sizes|join(", ")}}, {{ ncomponents }})
end subroutine {{ equi_name }}
2 changes: 1 addition & 1 deletion sympleints/templates/fortran_function.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ real({{ kind }}) :: {{ as_.lhs }}
{% endfor %}

{% for inds, res_line in results_iter %}
{{ res_name }}({{ inds|join(", ") }}) = {{ res_line }}
{{ res_name }}({{ loop.index }}) = {{ res_line }}
{% endfor %}
end subroutine {{ name }}
27 changes: 26 additions & 1 deletion sympleints/templates/fortran_module.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use mod_boys, only: boys
implicit none

type fp
procedure({{ interface_name }}) ,pointer ,nopass :: f =>null()
procedure({{ interface_name }}), pointer, nopass :: f =>null()
end type fp

interface
Expand All @@ -26,6 +26,31 @@ type(fp) :: func_array({{ func_arr_dims|join(", ") }})
contains
{{ init }}

{% if resort_func == "resort_ba_ab" %}
subroutine resort_ba_ab({{ res_name }}, sizea, sizeb, ncomponents)
real(real64), intent(in out) :: {{ res_name }}(:)
integer, intent(in) :: sizea, sizeb, ncomponents
integer :: sizeab, component_size
integer :: nc, a, b, i, j
real(real64) :: tmp(size({{ res_name }}))

sizeab = sizea * sizeb
i = 1
do nc = 1, ncomponents
do b = 1, sizeb
do a = 1, sizea
j = component_size + ((a-1) * sizeb) + b
tmp(j) = {{ res_name }}(i)
i = i + 1
end do
end do
component_size = component_size + sizeab
end do

{{ res_name }} = tmp
end subroutine resort_ba_ab
{% endif %}

{{ contr_driver }}

{% for func in funcs %}
Expand Down

0 comments on commit cf6cb79

Please sign in to comment.