From be94a97c3a44c76e5c93202a1a49c96cf28b3db2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Faveo=20H=C3=B6rold?= Date: Tue, 28 Oct 2025 00:50:13 +0100 Subject: [PATCH 1/2] Extend const_eval_nodes to propagate parameter arrays. --- dace/frontend/fortran/ast_desugaring.py | 48 ++++++++++++++++++++++--- dace/frontend/fortran/fortran_parser.py | 2 +- tests/fortran/ast_desugaring_test.py | 10 ++++-- 3 files changed, 53 insertions(+), 7 deletions(-) diff --git a/dace/frontend/fortran/ast_desugaring.py b/dace/frontend/fortran/ast_desugaring.py index 661872bd2b..c374da0ded 100644 --- a/dace/frontend/fortran/ast_desugaring.py +++ b/dace/frontend/fortran/ast_desugaring.py @@ -29,7 +29,7 @@ Deallocate_Stmt, Close_Stmt, Goto_Stmt, Continue_Stmt, Format_Stmt, Stmt_Function_Stmt, Internal_Subprogram_Part, \ Private_Components_Stmt, Generic_Spec, Language_Binding_Spec, Type_Attr_Spec, Suffix, Proc_Component_Def_Stmt, \ Proc_Decl, End_Type_Stmt, End_Interface_Stmt, Procedure_Declaration_Stmt, Pointer_Assignment_Stmt, Cycle_Stmt, \ - Equiv_Operand, Case_Value_Range_List + Equiv_Operand, Case_Value_Range_List, Ac_Value_List from fparser.two.Fortran2008 import Procedure_Stmt, Type_Declaration_Stmt, Error_Stop_Stmt from fparser.two.utils import Base, walk, BinaryOpBase, UnaryOpBase, NumberBase, BlockBase @@ -638,6 +638,45 @@ def _eval_real_literal(x: Union[Signed_Real_Literal_Constant, Real_Literal_Const def _const_eval_basic_type(expr: Base, alias_map: SPEC_TABLE) -> Optional[NUMPY_TYPES]: if isinstance(expr, (Part_Ref, Data_Ref)): + name_node = expr.children[0] + # Only support scalar array accesses for now + if (len(expr.children) > 1 and isinstance(expr.children[1], Section_Subscript_List)): + subsc = expr.children[1] + if (len(subsc.children) == 1): + # TODO index offset correction + idx = _const_eval_basic_type(subsc.children[0], alias_map) + if not idx: + # Array index is not constant + return None + # This is just copied behavior from 'Name' + # But we need to keep track of idx, so we can't just do + # return _const_eval_basic_type(name_node, alias_map) + spec = search_real_local_alias_spec(name_node, alias_map) + if not spec: + # Does not even have a valid identifier. + return None + decl = alias_map[spec] + if not isinstance(decl, Entity_Decl): + # Is not even a data entity. + return None + typ = find_type_of_entity(decl, alias_map) + if not typ or not typ.const: + # Does not have a constant type. + return None + init = atmost_one(children_of_type(decl, Initialization)) + # TODO: Add ref. + _, iexpr = init.children + if f"{iexpr}" == 'NULL()': + # We don't have good representation of "null pointer". + return None + # Expect an Array_Constructor + if (isinstance(iexpr, Array_Constructor)): + # Expect an Ac_Value_List + _, acvall, _ = iexpr.children + if (isinstance(acvall, Ac_Value_List)): + # TODO Bounds check here, but idx is non-normalized anyway so all of this is wrong no matter what + return _const_eval_basic_type(acvall.children[idx-1], alias_map) + # Fail otherwise return None elif isinstance(expr, Name): spec = search_real_local_alias_spec(expr, alias_map) @@ -3623,10 +3662,11 @@ def consolidate_global_data_into_arg(ast: Program, always_add_global_data_arg: b if not spart: continue for tdecl in children_of_type(spart, Type_Declaration_Stmt): - typ, attr, _ = tdecl.children + typ, attr, decl = tdecl.children if 'PARAMETER' in f"{attr}": - # This is a constant which should have been propagated away already. - continue + attr.items = [spec for spec in attr.children if spec.string != 'PARAMETER'] + if not attr.items: + tdecl.items = (typ, None, decl) all_global_vars.append(tdecl.tofortran()) all_derived_types = '\n'.join(all_derived_types) all_global_vars = '\n'.join(all_global_vars) diff --git a/dace/frontend/fortran/fortran_parser.py b/dace/frontend/fortran/fortran_parser.py index f406bdcbd9..296937b3bc 100644 --- a/dace/frontend/fortran/fortran_parser.py +++ b/dace/frontend/fortran/fortran_parser.py @@ -3069,7 +3069,7 @@ def run_fparser_transformations(ast: Program, cfg: ParseConfig): if cfg.consolidate_global_data: print("FParser Op: Consolidating the global variables of the AST...") - ast = consolidate_global_data_into_arg(ast) + ast = consolidate_global_data_into_arg(ast, cfg.entry_points) ast = prune_coarsely(ast, cfg.do_not_prune) _checkpoint_ast(cfg, 'ast_v4.f90', ast) diff --git a/tests/fortran/ast_desugaring_test.py b/tests/fortran/ast_desugaring_test.py index b9dff85461..38ec487065 100644 --- a/tests/fortran/ast_desugaring_test.py +++ b/tests/fortran/ast_desugaring_test.py @@ -1732,14 +1732,17 @@ def test_constant_expression_replacement(): contains subroutine foo implicit none - real :: res1, res2, res3, unk + real :: res1, res2, res3, res4, res5, unk real, parameter :: & x = -(three + 4.0), & y = -4.0, & z = 3.0 + real, parameter :: arr(3) = [3, 4, 5] res1 = unk ** x res2 = unk ** y res3 = unk ** z + res4 = arr(1) + res5 = arr(1) + arr(2) + arr(3) end subroutine foo end module main """).check_with_gfortran().get() @@ -1755,11 +1758,14 @@ def test_constant_expression_replacement(): CONTAINS SUBROUTINE foo IMPLICIT NONE - REAL :: res1, res2, res3, unk + REAL :: res1, res2, res3, res4, res5, unk REAL, PARAMETER :: x = - (7.0), y = - 4.0, z = 3.0 + REAL, PARAMETER :: arr(3) = [3, 4, 5] res1 = unk ** (- 7.0) res2 = unk ** (- 4.0) res3 = unk ** 3.0 + res4 = 3 + res5 = 12 END SUBROUTINE foo END MODULE main """.strip() From c9e57c8d815d8aef84c2c675731eca0f4e74bacb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Faveo=20H=C3=B6rold?= Date: Thu, 30 Oct 2025 02:01:01 +0100 Subject: [PATCH 2/2] Respect array bounds in _const_eval_basic_type. --- dace/frontend/fortran/ast_desugaring.py | 27 ++++++++++++++++++------- dace/frontend/fortran/fortran_parser.py | 2 +- tests/fortran/ast_desugaring_test.py | 6 ++++++ 3 files changed, 27 insertions(+), 8 deletions(-) diff --git a/dace/frontend/fortran/ast_desugaring.py b/dace/frontend/fortran/ast_desugaring.py index c374da0ded..25d280e498 100644 --- a/dace/frontend/fortran/ast_desugaring.py +++ b/dace/frontend/fortran/ast_desugaring.py @@ -29,7 +29,7 @@ Deallocate_Stmt, Close_Stmt, Goto_Stmt, Continue_Stmt, Format_Stmt, Stmt_Function_Stmt, Internal_Subprogram_Part, \ Private_Components_Stmt, Generic_Spec, Language_Binding_Spec, Type_Attr_Spec, Suffix, Proc_Component_Def_Stmt, \ Proc_Decl, End_Type_Stmt, End_Interface_Stmt, Procedure_Declaration_Stmt, Pointer_Assignment_Stmt, Cycle_Stmt, \ - Equiv_Operand, Case_Value_Range_List, Ac_Value_List + Equiv_Operand, Case_Value_Range_List, Ac_Value_List, Explicit_Shape_Spec_List from fparser.two.Fortran2008 import Procedure_Stmt, Type_Declaration_Stmt, Error_Stop_Stmt from fparser.two.utils import Base, walk, BinaryOpBase, UnaryOpBase, NumberBase, BlockBase @@ -639,13 +639,12 @@ def _eval_real_literal(x: Union[Signed_Real_Literal_Constant, Real_Literal_Const def _const_eval_basic_type(expr: Base, alias_map: SPEC_TABLE) -> Optional[NUMPY_TYPES]: if isinstance(expr, (Part_Ref, Data_Ref)): name_node = expr.children[0] - # Only support scalar array accesses for now + # TODO: Support multidimensional array access. if (len(expr.children) > 1 and isinstance(expr.children[1], Section_Subscript_List)): subsc = expr.children[1] if (len(subsc.children) == 1): - # TODO index offset correction idx = _const_eval_basic_type(subsc.children[0], alias_map) - if not idx: + if idx is None: # Array index is not constant return None # This is just copied behavior from 'Name' @@ -659,12 +658,23 @@ def _const_eval_basic_type(expr: Base, alias_map: SPEC_TABLE) -> Optional[NUMPY_ if not isinstance(decl, Entity_Decl): # Is not even a data entity. return None + # Find array declaration bounds + shape = singular(children_of_type(decl, Explicit_Shape_Spec_List)) + shape = singular(children_of_type(shape, Explicit_Shape_Spec)).children + assert len(shape) == 2 + assert shape[1] is not None + lbound = 1 + if shape[0] is not None: + lbound = _const_eval_basic_type(shape[0], alias_map) + ubound = _const_eval_basic_type(shape[1], alias_map) + if lbound is None or ubound is None: + # Shape is not constant + return None typ = find_type_of_entity(decl, alias_map) if not typ or not typ.const: # Does not have a constant type. return None init = atmost_one(children_of_type(decl, Initialization)) - # TODO: Add ref. _, iexpr = init.children if f"{iexpr}" == 'NULL()': # We don't have good representation of "null pointer". @@ -674,8 +684,8 @@ def _const_eval_basic_type(expr: Base, alias_map: SPEC_TABLE) -> Optional[NUMPY_ # Expect an Ac_Value_List _, acvall, _ = iexpr.children if (isinstance(acvall, Ac_Value_List)): - # TODO Bounds check here, but idx is non-normalized anyway so all of this is wrong no matter what - return _const_eval_basic_type(acvall.children[idx-1], alias_map) + assert lbound <= idx and idx <= ubound, f"Array index {idx} is out of bounds in {decl.name}" + return _const_eval_basic_type(acvall.children[idx-lbound], alias_map) # Fail otherwise return None elif isinstance(expr, Name): @@ -3664,6 +3674,9 @@ def consolidate_global_data_into_arg(ast: Program, always_add_global_data_arg: b for tdecl in children_of_type(spart, Type_Declaration_Stmt): typ, attr, decl = tdecl.children if 'PARAMETER' in f"{attr}": + # Parameter arrays cannot always be resolved, the indices may not be constant. + # Keep the array declaration but remove the PARAMETER attribute to avoid + # confusing the internal AST builder later. attr.items = [spec for spec in attr.children if spec.string != 'PARAMETER'] if not attr.items: tdecl.items = (typ, None, decl) diff --git a/dace/frontend/fortran/fortran_parser.py b/dace/frontend/fortran/fortran_parser.py index 296937b3bc..f406bdcbd9 100644 --- a/dace/frontend/fortran/fortran_parser.py +++ b/dace/frontend/fortran/fortran_parser.py @@ -3069,7 +3069,7 @@ def run_fparser_transformations(ast: Program, cfg: ParseConfig): if cfg.consolidate_global_data: print("FParser Op: Consolidating the global variables of the AST...") - ast = consolidate_global_data_into_arg(ast, cfg.entry_points) + ast = consolidate_global_data_into_arg(ast) ast = prune_coarsely(ast, cfg.do_not_prune) _checkpoint_ast(cfg, 'ast_v4.f90', ast) diff --git a/tests/fortran/ast_desugaring_test.py b/tests/fortran/ast_desugaring_test.py index 38ec487065..be373fedc7 100644 --- a/tests/fortran/ast_desugaring_test.py +++ b/tests/fortran/ast_desugaring_test.py @@ -1738,11 +1738,14 @@ def test_constant_expression_replacement(): y = -4.0, & z = 3.0 real, parameter :: arr(3) = [3, 4, 5] + real, parameter :: weird(-1:1) = [3, 4, 5] res1 = unk ** x res2 = unk ** y res3 = unk ** z res4 = arr(1) res5 = arr(1) + arr(2) + arr(3) + res4 = weird(0) + res5 = weird(1) + weird(0) - weird(-1) end subroutine foo end module main """).check_with_gfortran().get() @@ -1761,11 +1764,14 @@ def test_constant_expression_replacement(): REAL :: res1, res2, res3, res4, res5, unk REAL, PARAMETER :: x = - (7.0), y = - 4.0, z = 3.0 REAL, PARAMETER :: arr(3) = [3, 4, 5] + REAL, PARAMETER :: weird(- 1 : 1) = [3, 4, 5] res1 = unk ** (- 7.0) res2 = unk ** (- 4.0) res3 = unk ** 3.0 res4 = 3 res5 = 12 + res4 = 4 + res5 = 6 END SUBROUTINE foo END MODULE main """.strip()