Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions dace/frontend/fortran/ast_desugaring.py
Original file line number Diff line number Diff line change
Expand Up @@ -3208,6 +3208,53 @@ def _const_eval_node(n: Base) -> bool:
return ast


def unroll_loops(ast: Program, max_iter: int = 16) -> Program:
"""Unroll loops with static bounds."""
for node in reversed(walk(ast, (Block_Nonlabel_Do_Construct, Block_Label_Do_Construct))):
if walk(node, (Cycle_Stmt, Exit_Stmt)):
# TODO: Handle loop-altering control flow.
continue
do_stmt = node.children[0]
assert isinstance(do_stmt, (Label_Do_Stmt, Nonlabel_Do_Stmt))
assert isinstance(node.children[-1], End_Do_Stmt)
do_ops = node.children[1:-1]

loop_control = singular(children_of_type(do_stmt, Loop_Control))

_, cntexpr, _, _ = loop_control.children
if cntexpr:
loopvar, looprange = cntexpr
unrollable = True
for rng in looprange:
if not isinstance(rng, Int_Literal_Constant):
# We need the loop range to be constant
unrollable = False
if not unrollable:
continue
assert len(looprange) >= 2
# Don't modify the original looprange in case we bail
looprange = looprange.copy()
# Tweak looprange so we can just pass it to Python
for i, rng in enumerate(looprange):
looprange[i] = int(rng.tofortran())
# Increment 'end', since Python range is exclusive
looprange[1] += 1
# Add default 'step' 1 if it doesn't exist
if len(looprange) == 2:
looprange.append(1)
if max_iter > 0 and len(range(*looprange)) > max_iter:
continue
unrolled = []
for i in range(*looprange):
unrolled.append(Assignment_Stmt(f"{loopvar} = {i}"))
for op in do_ops:
# unrolled.append(copy_fparser_node(op))
unrolled.append(deepcopy(op))
replace_node(node, unrolled)

return ast


@dataclass
class ConstTypeInjection:
scope_spec: Optional[SPEC] # Only replace within this scope object.
Expand Down
9 changes: 7 additions & 2 deletions dace/frontend/fortran/fortran_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
make_practically_constant_arguments_constants, exploit_locally_constant_variables, \
assign_globally_unique_subprogram_names, convert_data_statements_into_assignments, \
deconstruct_statement_functions, assign_globally_unique_variable_names, deconstuct_goto_statements, remove_self, \
prune_coarsely, consolidate_global_data_into_arg, identifier_specs
prune_coarsely, consolidate_global_data_into_arg, identifier_specs, unroll_loops
from dace.frontend.fortran.ast_internal_classes import FNode, Main_Program_Node, Name_Node, Var_Decl_Node
from dace.frontend.fortran.ast_internal_classes import Program_Node
from dace.frontend.fortran.ast_utils import children_of_type, mywalk, atmost_one
Expand Down Expand Up @@ -2866,7 +2866,8 @@ def __init__(self, sources: Union[None, List[Path], Dict[str, str]] = None,
ast_checkpoint_dir: Union[None, str, Path] = None,
consolidate_global_data: bool = True,
rename_uniquely: bool = True,
do_not_prune_type_components: bool = False):
do_not_prune_type_components: bool = False,
unroll_loops: int = 0):
# Make the configs canonical, by processing the various types upfront.
if not sources:
sources: Dict[str, str] = {}
Expand Down Expand Up @@ -2906,6 +2907,7 @@ def __init__(self, sources: Union[None, List[Path], Dict[str, str]] = None,
self.consolidate_global_data = consolidate_global_data
self.rename_uniquely = rename_uniquely
self.do_not_prune_type_components = do_not_prune_type_components
self.unroll_loops = unroll_loops # Integer gives maximum number of iterations to unroll. Negative means unroll all.

def set_all_possible_entry_points_from(self, ast: Program):
# Keep all the possible entry points.
Expand Down Expand Up @@ -3053,6 +3055,9 @@ def run_fparser_transformations(ast: Program, cfg: ParseConfig):
print("FParser Op: Fix arguments...")
# Fix the practically constant arguments, just in case.
ast = make_practically_constant_arguments_constants(ast, cfg.entry_points)
if cfg.unroll_loops != 0:
print("FParser Op: Unroll loops...")
ast = unroll_loops(ast, cfg.unroll_loops)
print("FParser Op: Fix local vars...")
# Fix the locally constant variables, just in case.
ast = exploit_locally_constant_variables(ast)
Expand Down
7 changes: 6 additions & 1 deletion dace/frontend/fortran/tools/create_preprocessed_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ def main():
'If nothing is given, then will write to STDOUT.')
argp.add_argument('--noop', type=str, required=False, action='append', default=[],
help='(Optional) Functions or subroutine to make no-op.')
argp.add_argument('--unroll_loops', type=int, required=False, default=0,
help='(Optional) Maximum number of iterations to allow for unrolling loops with constant bounds. Negative input unrolls all loops with constant bounds.')
argp.add_argument('-d', '--checkpoint_dir', type=str, required=False, default=None,
help='(Optional) If specified, the AST in various stages of preprocessing will be written as'
'Fortran code in there.')
Expand All @@ -97,6 +99,8 @@ def main():
noops = [tuple(np.split('.')) for np in args.noop]
print(f"Will be making these as no-ops: {noops}")

unroll_loops = args.unroll_loops

checkpoint_dir = args.checkpoint_dir
if checkpoint_dir:
print(f"Will be writing the checkpoint ASTs in: {checkpoint_dir}")
Expand All @@ -118,7 +122,8 @@ def main():
make_noop=noops,
ast_checkpoint_dir=checkpoint_dir,
consolidate_global_data=consolidate_global_data,
rename_uniquely=rename_uniquely)
rename_uniquely=rename_uniquely,
unroll_loops=unroll_loops)
cfg.sources['_stubs.f90'] = STUBS
cfg.sources['_builtins.f90'] = BUILTINS

Expand Down
162 changes: 161 additions & 1 deletion tests/fortran/ast_desugaring_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
make_practically_constant_arguments_constants, make_practically_constant_global_vars_constants, \
exploit_locally_constant_variables, create_global_initializers, convert_data_statements_into_assignments, \
deconstruct_statement_functions, deconstuct_goto_statements, SPEC, remove_access_and_bind_statements, \
identifier_specs, alias_specs, consolidate_uses, consolidate_global_data_into_arg, prune_coarsely
identifier_specs, alias_specs, consolidate_uses, consolidate_global_data_into_arg, prune_coarsely, \
unroll_loops
from dace.frontend.fortran.fortran_parser import construct_full_ast
from tests.fortran.fortran_test_helper import SourceCodeBuilder

Expand Down Expand Up @@ -3093,3 +3094,162 @@ def test_constant_function_evaluation():
""".strip()
assert got == want
SourceCodeBuilder().add_file(got).check_with_gfortran()


def test_unroll_loops():
"""Tests whether basic unrolling transformation works."""
sources, main = SourceCodeBuilder().add_file("""
subroutine main(d)
implicit none
integer :: idx, d
do idx=1,3
d = d + idx
end do
end subroutine main
""").check_with_gfortran().get()
ast = parse_and_improve(sources)
ast = unroll_loops(ast)

got = ast.tofortran()
want = """
SUBROUTINE main(d)
IMPLICIT NONE
INTEGER :: idx, d
idx = 1
d = d + idx
idx = 2
d = d + idx
idx = 3
d = d + idx
END SUBROUTINE main
""".strip()
assert got == want
SourceCodeBuilder().add_file(got).check_with_gfortran()


def test_unroll_loops_nested():
"""Tests whether nested unrolling transformation works."""
sources, main = SourceCodeBuilder().add_file("""
subroutine main(d)
implicit none
integer :: idx, jdx, d
do idx=1,2
do jdx=1,2
d = d + jdx
end do
d = d + idx
end do
end subroutine main
""").check_with_gfortran().get()
ast = parse_and_improve(sources)
ast = unroll_loops(ast)

got = ast.tofortran()
want = """
SUBROUTINE main(d)
IMPLICIT NONE
INTEGER :: idx, jdx, d
idx = 1
jdx = 1
d = d + jdx
jdx = 2
d = d + jdx
d = d + idx
idx = 2
jdx = 1
d = d + jdx
jdx = 2
d = d + jdx
d = d + idx
END SUBROUTINE main
""".strip()
assert got == want
SourceCodeBuilder().add_file(got).check_with_gfortran()


def test_unroll_loops_bounds():
"""Tests whether loop size check works in unrolling transformation."""
sources, main = SourceCodeBuilder().add_file("""
subroutine main(d)
implicit none
integer :: idx, d
do idx=1,2
d = d + idx
end do
do idx=1,10,2
d = d + idx
end do
end subroutine main
""").check_with_gfortran().get()
ast = parse_and_improve(sources)
ast = unroll_loops(ast, 4)

got = ast.tofortran()
want = """
SUBROUTINE main(d)
IMPLICIT NONE
INTEGER :: idx, d
idx = 1
d = d + idx
idx = 2
d = d + idx
DO idx = 1, 10, 2
d = d + idx
END DO
END SUBROUTINE main
""".strip()
assert got == want
SourceCodeBuilder().add_file(got).check_with_gfortran()


def test_unroll_loops_invalid():
"""Ignore loops with CYCLE or EXIT"""
sources, main = SourceCodeBuilder().add_file("""
subroutine main(d)
implicit none
integer :: idx, d
do idx=1,2
d = d + idx
end do
do idx=2,3
d = d - 2
if (d < 1) then
CYCLE
end if
end do
do idx=1,4
d = d+idx
if (idx > 3) then
EXIT
end if
end do
end subroutine main
""").check_with_gfortran().get()
ast = parse_and_improve(sources)
ast = unroll_loops(ast)

got = ast.tofortran()
want = """
SUBROUTINE main(d)
IMPLICIT NONE
INTEGER :: idx, d
idx = 1
d = d + idx
idx = 2
d = d + idx
DO idx = 2, 3
d = d - 2
IF (d < 1) THEN
CYCLE
END IF
END DO
DO idx = 1, 4
d = d + idx
IF (idx > 3) THEN
EXIT
END IF
END DO
END SUBROUTINE main
""".strip()
assert got == want
SourceCodeBuilder().add_file(got).check_with_gfortran()