From a64a8cfc87c23dc863a55aa9107cee924c110cb6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Faveo=20H=C3=B6rold?= Date: Mon, 27 Oct 2025 17:42:22 +0100 Subject: [PATCH 1/3] Add basic unroll_loops fparser transformation. --- dace/frontend/fortran/ast_desugaring.py | 40 +++++++++++++ dace/frontend/fortran/fortran_parser.py | 4 +- tests/fortran/ast_desugaring_test.py | 74 ++++++++++++++++++++++++- 3 files changed, 116 insertions(+), 2 deletions(-) diff --git a/dace/frontend/fortran/ast_desugaring.py b/dace/frontend/fortran/ast_desugaring.py index 661872bd2b..5bf37c5be5 100644 --- a/dace/frontend/fortran/ast_desugaring.py +++ b/dace/frontend/fortran/ast_desugaring.py @@ -3208,6 +3208,46 @@ def _const_eval_node(n: Base) -> bool: return ast +def unroll_loops(ast: Program) -> Program: + """Unroll loops with static bounds.""" + for node in reversed(walk(ast, (Block_Nonlabel_Do_Construct, Block_Label_Do_Construct))): + do_stmt = node.children[0] + assert isinstance(do_stmt, (Label_Do_Stmt, Nonlabel_Do_Stmt)) + assert isinstance(node.children[-1], End_Do_Stmt) + do_ops = node.children[1:-1] + + loop_control = singular(children_of_type(do_stmt, Loop_Control)) + + _, cntexpr, _, _ = loop_control.children + if cntexpr: + loopvar, looprange = cntexpr + unrollable = True + for rng in looprange: + if not isinstance(rng, Int_Literal_Constant): + # We need the loop range to be constant + unrollable = False + if not unrollable: + continue + assert len(looprange) >= 2 + # Tweak looprange so we can just pass it to Python + for i, rng in enumerate(looprange): + looprange[i] = int(rng.tofortran()) + # Increment 'end', since Python range is exclusive + looprange[1] += 1 + # Add default 'step' 1 if it doesn't exist + if len(looprange) == 2: + looprange.append(1) + unrolled = [] + for i in range(*looprange): + unrolled.append(Assignment_Stmt(f"{loopvar} = {i}")) + for op in do_ops: + # unrolled.append(copy_fparser_node(op)) + unrolled.append(deepcopy(op)) + replace_node(node, unrolled) + + return ast + + @dataclass class ConstTypeInjection: scope_spec: Optional[SPEC] # Only replace within this scope object. diff --git a/dace/frontend/fortran/fortran_parser.py b/dace/frontend/fortran/fortran_parser.py index f406bdcbd9..becfc9d115 100644 --- a/dace/frontend/fortran/fortran_parser.py +++ b/dace/frontend/fortran/fortran_parser.py @@ -34,7 +34,7 @@ make_practically_constant_arguments_constants, exploit_locally_constant_variables, \ assign_globally_unique_subprogram_names, convert_data_statements_into_assignments, \ deconstruct_statement_functions, assign_globally_unique_variable_names, deconstuct_goto_statements, remove_self, \ - prune_coarsely, consolidate_global_data_into_arg, identifier_specs + prune_coarsely, consolidate_global_data_into_arg, identifier_specs, unroll_loops from dace.frontend.fortran.ast_internal_classes import FNode, Main_Program_Node, Name_Node, Var_Decl_Node from dace.frontend.fortran.ast_internal_classes import Program_Node from dace.frontend.fortran.ast_utils import children_of_type, mywalk, atmost_one @@ -3053,6 +3053,8 @@ def run_fparser_transformations(ast: Program, cfg: ParseConfig): print("FParser Op: Fix arguments...") # Fix the practically constant arguments, just in case. ast = make_practically_constant_arguments_constants(ast, cfg.entry_points) + print("FParser Op: Unroll loops...") + ast = unroll_loops(ast) print("FParser Op: Fix local vars...") # Fix the locally constant variables, just in case. ast = exploit_locally_constant_variables(ast) diff --git a/tests/fortran/ast_desugaring_test.py b/tests/fortran/ast_desugaring_test.py index b9dff85461..77873685d3 100644 --- a/tests/fortran/ast_desugaring_test.py +++ b/tests/fortran/ast_desugaring_test.py @@ -10,7 +10,8 @@ make_practically_constant_arguments_constants, make_practically_constant_global_vars_constants, \ exploit_locally_constant_variables, create_global_initializers, convert_data_statements_into_assignments, \ deconstruct_statement_functions, deconstuct_goto_statements, SPEC, remove_access_and_bind_statements, \ - identifier_specs, alias_specs, consolidate_uses, consolidate_global_data_into_arg, prune_coarsely + identifier_specs, alias_specs, consolidate_uses, consolidate_global_data_into_arg, prune_coarsely, \ + unroll_loops from dace.frontend.fortran.fortran_parser import construct_full_ast from tests.fortran.fortran_test_helper import SourceCodeBuilder @@ -3093,3 +3094,74 @@ def test_constant_function_evaluation(): """.strip() assert got == want SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_unroll_loops(): + """Tests whether basic unrolling transformation works.""" + sources, main = SourceCodeBuilder().add_file(""" +subroutine main(d) + implicit none + integer :: idx, d + do, idx=1,3 + d = d + idx + end do +end subroutine main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = unroll_loops(ast) + + got = ast.tofortran() + want = """ +SUBROUTINE main(d) + IMPLICIT NONE + INTEGER :: idx, d + idx = 1 + d = d + idx + idx = 2 + d = d + idx + idx = 3 + d = d + idx +END SUBROUTINE main + """.strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_unroll_loops_nested(): + """Tests whether nested unrolling transformation works.""" + sources, main = SourceCodeBuilder().add_file(""" +subroutine main(d) + implicit none + integer :: idx, jdx, d + do, idx=1,2 + do, jdx=1,2 + d = d + jdx + end do + d = d + idx + end do +end subroutine main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = unroll_loops(ast) + + got = ast.tofortran() + want = """ +SUBROUTINE main(d) + IMPLICIT NONE + INTEGER :: idx, jdx, d + idx = 1 + jdx = 1 + d = d + jdx + jdx = 2 + d = d + jdx + d = d + idx + idx = 2 + jdx = 1 + d = d + jdx + jdx = 2 + d = d + jdx + d = d + idx +END SUBROUTINE main + """.strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() From 8e250b3a8b55f3512489914e570ffe0feb559da0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Faveo=20H=C3=B6rold?= Date: Mon, 27 Oct 2025 18:12:01 +0100 Subject: [PATCH 2/3] Make unroll_loops transformation configurable. --- dace/frontend/fortran/ast_desugaring.py | 6 ++- dace/frontend/fortran/fortran_parser.py | 9 ++-- .../fortran/tools/create_preprocessed_ast.py | 7 +++- tests/fortran/ast_desugaring_test.py | 41 +++++++++++++++++-- 4 files changed, 55 insertions(+), 8 deletions(-) diff --git a/dace/frontend/fortran/ast_desugaring.py b/dace/frontend/fortran/ast_desugaring.py index 5bf37c5be5..6572b656cf 100644 --- a/dace/frontend/fortran/ast_desugaring.py +++ b/dace/frontend/fortran/ast_desugaring.py @@ -3208,7 +3208,7 @@ def _const_eval_node(n: Base) -> bool: return ast -def unroll_loops(ast: Program) -> Program: +def unroll_loops(ast: Program, max_iter: int = 16) -> Program: """Unroll loops with static bounds.""" for node in reversed(walk(ast, (Block_Nonlabel_Do_Construct, Block_Label_Do_Construct))): do_stmt = node.children[0] @@ -3229,6 +3229,8 @@ def unroll_loops(ast: Program) -> Program: if not unrollable: continue assert len(looprange) >= 2 + # Don't modify the original looprange in case we bail + looprange = looprange.copy() # Tweak looprange so we can just pass it to Python for i, rng in enumerate(looprange): looprange[i] = int(rng.tofortran()) @@ -3237,6 +3239,8 @@ def unroll_loops(ast: Program) -> Program: # Add default 'step' 1 if it doesn't exist if len(looprange) == 2: looprange.append(1) + if max_iter > 0 and len(range(*looprange)) > max_iter: + continue unrolled = [] for i in range(*looprange): unrolled.append(Assignment_Stmt(f"{loopvar} = {i}")) diff --git a/dace/frontend/fortran/fortran_parser.py b/dace/frontend/fortran/fortran_parser.py index becfc9d115..edf9bd19ff 100644 --- a/dace/frontend/fortran/fortran_parser.py +++ b/dace/frontend/fortran/fortran_parser.py @@ -2866,7 +2866,8 @@ def __init__(self, sources: Union[None, List[Path], Dict[str, str]] = None, ast_checkpoint_dir: Union[None, str, Path] = None, consolidate_global_data: bool = True, rename_uniquely: bool = True, - do_not_prune_type_components: bool = False): + do_not_prune_type_components: bool = False, + unroll_loops: int = 0): # Make the configs canonical, by processing the various types upfront. if not sources: sources: Dict[str, str] = {} @@ -2906,6 +2907,7 @@ def __init__(self, sources: Union[None, List[Path], Dict[str, str]] = None, self.consolidate_global_data = consolidate_global_data self.rename_uniquely = rename_uniquely self.do_not_prune_type_components = do_not_prune_type_components + self.unroll_loops = unroll_loops # Integer gives maximum number of iterations to unroll. Negative means unroll all. def set_all_possible_entry_points_from(self, ast: Program): # Keep all the possible entry points. @@ -3053,8 +3055,9 @@ def run_fparser_transformations(ast: Program, cfg: ParseConfig): print("FParser Op: Fix arguments...") # Fix the practically constant arguments, just in case. ast = make_practically_constant_arguments_constants(ast, cfg.entry_points) - print("FParser Op: Unroll loops...") - ast = unroll_loops(ast) + if cfg.unroll_loops != 0: + print("FParser Op: Unroll loops...") + ast = unroll_loops(ast, cfg.unroll_loops) print("FParser Op: Fix local vars...") # Fix the locally constant variables, just in case. ast = exploit_locally_constant_variables(ast) diff --git a/dace/frontend/fortran/tools/create_preprocessed_ast.py b/dace/frontend/fortran/tools/create_preprocessed_ast.py index 2c33de4e56..2c7075733d 100644 --- a/dace/frontend/fortran/tools/create_preprocessed_ast.py +++ b/dace/frontend/fortran/tools/create_preprocessed_ast.py @@ -78,6 +78,8 @@ def main(): 'If nothing is given, then will write to STDOUT.') argp.add_argument('--noop', type=str, required=False, action='append', default=[], help='(Optional) Functions or subroutine to make no-op.') + argp.add_argument('--unroll_loops', type=int, required=False, default=0, + help='(Optional) Maximum number of iterations to allow for unrolling loops with constant bounds. Negative input unrolls all loops with constant bounds.') argp.add_argument('-d', '--checkpoint_dir', type=str, required=False, default=None, help='(Optional) If specified, the AST in various stages of preprocessing will be written as' 'Fortran code in there.') @@ -97,6 +99,8 @@ def main(): noops = [tuple(np.split('.')) for np in args.noop] print(f"Will be making these as no-ops: {noops}") + unroll_loops = args.unroll_loops + checkpoint_dir = args.checkpoint_dir if checkpoint_dir: print(f"Will be writing the checkpoint ASTs in: {checkpoint_dir}") @@ -118,7 +122,8 @@ def main(): make_noop=noops, ast_checkpoint_dir=checkpoint_dir, consolidate_global_data=consolidate_global_data, - rename_uniquely=rename_uniquely) + rename_uniquely=rename_uniquely, + unroll_loops=unroll_loops) cfg.sources['_stubs.f90'] = STUBS cfg.sources['_builtins.f90'] = BUILTINS diff --git a/tests/fortran/ast_desugaring_test.py b/tests/fortran/ast_desugaring_test.py index 77873685d3..53eb74c207 100644 --- a/tests/fortran/ast_desugaring_test.py +++ b/tests/fortran/ast_desugaring_test.py @@ -3102,7 +3102,7 @@ def test_unroll_loops(): subroutine main(d) implicit none integer :: idx, d - do, idx=1,3 + do idx=1,3 d = d + idx end do end subroutine main @@ -3133,8 +3133,8 @@ def test_unroll_loops_nested(): subroutine main(d) implicit none integer :: idx, jdx, d - do, idx=1,2 - do, jdx=1,2 + do idx=1,2 + do jdx=1,2 d = d + jdx end do d = d + idx @@ -3165,3 +3165,38 @@ def test_unroll_loops_nested(): """.strip() assert got == want SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_unroll_loops_bounds(): + """Tests whether loop size check works in unrolling transformation.""" + sources, main = SourceCodeBuilder().add_file(""" +subroutine main(d) + implicit none + integer :: idx, d + do idx=1,2 + d = d + idx + end do + do idx=1,10,2 + d = d + idx + end do +end subroutine main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = unroll_loops(ast, 4) + + got = ast.tofortran() + want = """ +SUBROUTINE main(d) + IMPLICIT NONE + INTEGER :: idx, d + idx = 1 + d = d + idx + idx = 2 + d = d + idx + DO idx = 1, 10, 2 + d = d + idx + END DO +END SUBROUTINE main + """.strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() From 15fbd5031a8fa3b819ba261f9df23d99556527cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Faveo=20H=C3=B6rold?= Date: Tue, 4 Nov 2025 09:44:21 +0100 Subject: [PATCH 3/3] Ignore loops with CYCLE or EXIT statements. --- dace/frontend/fortran/ast_desugaring.py | 3 ++ tests/fortran/ast_desugaring_test.py | 53 +++++++++++++++++++++++++ 2 files changed, 56 insertions(+) diff --git a/dace/frontend/fortran/ast_desugaring.py b/dace/frontend/fortran/ast_desugaring.py index 6572b656cf..3348e829ea 100644 --- a/dace/frontend/fortran/ast_desugaring.py +++ b/dace/frontend/fortran/ast_desugaring.py @@ -3211,6 +3211,9 @@ def _const_eval_node(n: Base) -> bool: def unroll_loops(ast: Program, max_iter: int = 16) -> Program: """Unroll loops with static bounds.""" for node in reversed(walk(ast, (Block_Nonlabel_Do_Construct, Block_Label_Do_Construct))): + if walk(node, (Cycle_Stmt, Exit_Stmt)): + # TODO: Handle loop-altering control flow. + continue do_stmt = node.children[0] assert isinstance(do_stmt, (Label_Do_Stmt, Nonlabel_Do_Stmt)) assert isinstance(node.children[-1], End_Do_Stmt) diff --git a/tests/fortran/ast_desugaring_test.py b/tests/fortran/ast_desugaring_test.py index 53eb74c207..8c908fad5f 100644 --- a/tests/fortran/ast_desugaring_test.py +++ b/tests/fortran/ast_desugaring_test.py @@ -3200,3 +3200,56 @@ def test_unroll_loops_bounds(): """.strip() assert got == want SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_unroll_loops_invalid(): + """Ignore loops with CYCLE or EXIT""" + sources, main = SourceCodeBuilder().add_file(""" +subroutine main(d) + implicit none + integer :: idx, d + do idx=1,2 + d = d + idx + end do + do idx=2,3 + d = d - 2 + if (d < 1) then + CYCLE + end if + end do + do idx=1,4 + d = d+idx + if (idx > 3) then + EXIT + end if + end do +end subroutine main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = unroll_loops(ast) + + got = ast.tofortran() + want = """ +SUBROUTINE main(d) + IMPLICIT NONE + INTEGER :: idx, d + idx = 1 + d = d + idx + idx = 2 + d = d + idx + DO idx = 2, 3 + d = d - 2 + IF (d < 1) THEN + CYCLE + END IF + END DO + DO idx = 1, 4 + d = d + idx + IF (idx > 3) THEN + EXIT + END IF + END DO +END SUBROUTINE main + """.strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran()