Skip to content

Commit

Permalink
mod: use nested loops for resorting integrals
Browse files Browse the repository at this point in the history
  • Loading branch information
eljost committed Feb 27, 2024
1 parent d56f5f2 commit aec9dc3
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 44 deletions.
4 changes: 2 additions & 2 deletions sympleints/Functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class Functions:
doc_func: Optional[Callable] = None
comment: str = ""
boys_func: Optional[str] = None
ncomponents: int = 0
ncomponents: int = 1
with_ref_center: bool = True
full_name: Optional["str"] = None
l_aux_max: Optional[int] = None
Expand All @@ -39,7 +39,7 @@ class Functions:
def __post_init__(self):
assert self.l_max >= 0
assert len(self.coeffs) == len(self.exponents) == len(self.centers)
assert self.ncomponents >= 0
assert self.ncomponents >= 1

L_iter, inner = self.ls_exprs
self.L_iter = L_iter
Expand Down
28 changes: 15 additions & 13 deletions sympleints/PythonRenderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ class PythonRenderer(Renderer):
"module": "py_module.tpl",
}
_primitive = False
_drop_dim = True
_drop_dim = True # TODO: remove this?!
resort_func_dict = {
2: ("resort_ba_ab", (2, 1)),
3: ("resort_bac_abc", (2, 1)),
}

def render_function(
self,
Expand Down Expand Up @@ -84,27 +88,23 @@ def render_equi_function(
equi_name,
equi_inds,
shape,
from_axes,
to_axes,
):
resort_func, new_order = self.resort_func_dict[functions.nbfs]
sizes = shape
sizes = [shape[i] for i in new_order]
equi_args = ", ".join(functions.full_args_for_bf_inds(equi_inds))
args = ", ".join(functions.full_args)
ncomponents = functions.ncomponents

tpl = self.get_template(key="equi_func")
nbfs = functions.nbfs
assert nbfs in (2, 3), "Implement other cases in template!"
if functions.ncomponents == 1:
from_axes = tuple([i - 1 for i in from_axes[1:]])
to_axes = tuple([i - 1 for i in to_axes[1:]])
rendered = tpl.render(
equi_name=equi_name,
equi_args=equi_args,
name=name,
args=args,
from_axes=from_axes,
to_axes=to_axes,
shape=shape,
functions=functions,
primitive=functions.primitive,
sizes=sizes,
resort_func=resort_func,
ncomponents=ncomponents,
)
return rendered

Expand All @@ -115,6 +115,7 @@ def render_func_dict(self, name, rendered_funcs):

def render_module(self, functions, rendered_funcs, **tpl_kwargs):
func_dict = self.render_func_dict(functions.name, rendered_funcs)
resort_func, _ = self.resort_func_dict[functions.nbfs]
tpl = self.get_template(key="module")
_tpl_kwargs = {
"header": functions.header,
Expand All @@ -124,6 +125,7 @@ def render_module(self, functions, rendered_funcs, **tpl_kwargs):
"func_dict": func_dict,
"name": functions.name,
"args": functions.full_args,
"resort_func": resort_func,
}
_tpl_kwargs.update(tpl_kwargs)
rendered = tpl.render(**_tpl_kwargs)
Expand Down
25 changes: 2 additions & 23 deletions sympleints/Renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def render_equi_function(
def render_functions(self, functions: Functions):
args = functions.full_args
ncomponents = functions.ncomponents
assert ncomponents > 0
if len(hermi_inds := functions.hermitian) >= 2:
hermi_inds = functions.hermitian
assert len(hermi_inds) == 2
Expand Down Expand Up @@ -124,34 +125,12 @@ def render_functions(self, functions: Functions):
L_tots_equi = tuple(L_tots_equi)
name_equi = func_name_from_Ls(functions.name, L_tots_equi)

from_axes = tuple(equi_inds)
to_axes = tuple(org_inds)

nbfs = functions.nbfs
# Axes/indices are missing
if len(from_axes) != nbfs:
to_axes = tuple(range(nbfs))
from_axes = list(range(nbfs))
from_axes[hi], from_axes[hj] = from_axes[hj], from_axes[hi]
from_axes = tuple(from_axes)

more_components = functions.ncomponents > 1
add_axis = more_components or (not self._drop_dim and nbfs == 2)
if add_axis:
from_axes = tuple([0] + [fa + 1 for fa in from_axes])
to_axes = tuple([0] + [ta + 1 for ta in to_axes])

# Drop first dimension when only 1 or 0 components are present
reshape = shape if more_components else shape[1:]

func_equi = self.render_equi_function(
functions,
name,
name_equi,
equi_inds,
reshape,
from_axes,
to_axes,
shape,
)
rendered_funcs.append(
RenderedFunction(name=name_equi, Ls=L_tots_equi, text=func_equi)
Expand Down
7 changes: 3 additions & 4 deletions sympleints/templates/numba_equi_func.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ def {{ equi_name }}({{ args }}, result):
"""See docstring of {{ name }}."""

# Call equivalent function and write to result
tmp = numpy.zeros_like(result)
{{ name }}({{ equi_args }}, tmp)
result[:] {% if not primitive %}+{% endif %}= numpy.transpose(tmp.reshape({{ shape|join(", ") }}), axes={{ from_axes }}).flatten()

# tmp = numpy.zeros_like(result)
{{ name }}({{ equi_args }}, result)
{{ resort_func }}(result, {{ sizes|join(", ")}}, {{ ncomponents }})
23 changes: 23 additions & 0 deletions sympleints/templates/numba_module.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,29 @@ from pysisyphus.wavefunction.ints.boys import boys
{{ ai }}
{% endfor %}

{% if resort_func == "resort_ba_ab" %}
@numba.jit(
nopython=True,
nogil=True,
fastmath=True,
cache=True,
)
def resort_ba_ab(result, sizea, sizeb, ncomponents):
assert ncomponents > 0
tmp = numpy.zeros_like(result)
i = 0 # Original index
sizeab = sizea * sizeb
component_size = 0
for _ in range(ncomponents):
for b in range(sizeb):
for a in range(sizea):
j = component_size + (a * sizeb) + b # New index
tmp[j] = result[i]
i += 1
component_size += sizeab
result[:] = tmp
{% endif %}

func_type = numba.types.FunctionType(
{{ func_type }}
)
Expand Down
3 changes: 1 addition & 2 deletions sympleints/templates/py_equi_func.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,4 @@ def {{ equi_name }}({{ args }}, result):

# Calculate values w/ swapped arguments
{{ name }}({{ equi_args }}, result)
# Swap two axes
result[:] = numpy.moveaxis(result.reshape({{ shape|join(",") }}), {{ from_axes }}, {{ to_axes }}).flatten()
{{ resort_func }}(result, {{ sizes|join(", ")}}, {{ ncomponents }})
17 changes: 17 additions & 0 deletions sympleints/templates/py_module.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,23 @@ from {{ boys }} import boys
{{ ai }}
{% endfor %}

{% if resort_func == "resort_ba_ab" %}
def resort_ba_ab(result, sizea, sizeb, ncomponents):
assert ncomponents > 0
tmp = numpy.zeros_like(result)
i = 0 # Original index
sizeab = sizea * sizeb
component_size = 0
for _ in range(ncomponents):
for b in range(sizeb):
for a in range(sizea):
j = component_size + (a * sizeb) + b # New index
tmp[j] = result[i]
i += 1
component_size += sizeab
result[:] = tmp
{% endif %}

{% for func in funcs %}
{{ func.text }}
{% endfor %}
Expand Down

0 comments on commit aec9dc3

Please sign in to comment.