diff --git a/dace/frontend/fortran/ast_desugaring.py b/dace/frontend/fortran/ast_desugaring.py index 661872bd2b..25d280e498 100644 --- a/dace/frontend/fortran/ast_desugaring.py +++ b/dace/frontend/fortran/ast_desugaring.py @@ -29,7 +29,7 @@ Deallocate_Stmt, Close_Stmt, Goto_Stmt, Continue_Stmt, Format_Stmt, Stmt_Function_Stmt, Internal_Subprogram_Part, \ Private_Components_Stmt, Generic_Spec, Language_Binding_Spec, Type_Attr_Spec, Suffix, Proc_Component_Def_Stmt, \ Proc_Decl, End_Type_Stmt, End_Interface_Stmt, Procedure_Declaration_Stmt, Pointer_Assignment_Stmt, Cycle_Stmt, \ - Equiv_Operand, Case_Value_Range_List + Equiv_Operand, Case_Value_Range_List, Ac_Value_List, Explicit_Shape_Spec_List from fparser.two.Fortran2008 import Procedure_Stmt, Type_Declaration_Stmt, Error_Stop_Stmt from fparser.two.utils import Base, walk, BinaryOpBase, UnaryOpBase, NumberBase, BlockBase @@ -638,6 +638,55 @@ def _eval_real_literal(x: Union[Signed_Real_Literal_Constant, Real_Literal_Const def _const_eval_basic_type(expr: Base, alias_map: SPEC_TABLE) -> Optional[NUMPY_TYPES]: if isinstance(expr, (Part_Ref, Data_Ref)): + name_node = expr.children[0] + # TODO: Support multidimensional array access. + if (len(expr.children) > 1 and isinstance(expr.children[1], Section_Subscript_List)): + subsc = expr.children[1] + if (len(subsc.children) == 1): + idx = _const_eval_basic_type(subsc.children[0], alias_map) + if idx is None: + # Array index is not constant + return None + # This is just copied behavior from 'Name' + # But we need to keep track of idx, so we can't just do + # return _const_eval_basic_type(name_node, alias_map) + spec = search_real_local_alias_spec(name_node, alias_map) + if not spec: + # Does not even have a valid identifier. + return None + decl = alias_map[spec] + if not isinstance(decl, Entity_Decl): + # Is not even a data entity. + return None + # Find array declaration bounds + shape = singular(children_of_type(decl, Explicit_Shape_Spec_List)) + shape = singular(children_of_type(shape, Explicit_Shape_Spec)).children + assert len(shape) == 2 + assert shape[1] is not None + lbound = 1 + if shape[0] is not None: + lbound = _const_eval_basic_type(shape[0], alias_map) + ubound = _const_eval_basic_type(shape[1], alias_map) + if lbound is None or ubound is None: + # Shape is not constant + return None + typ = find_type_of_entity(decl, alias_map) + if not typ or not typ.const: + # Does not have a constant type. + return None + init = atmost_one(children_of_type(decl, Initialization)) + _, iexpr = init.children + if f"{iexpr}" == 'NULL()': + # We don't have good representation of "null pointer". + return None + # Expect an Array_Constructor + if (isinstance(iexpr, Array_Constructor)): + # Expect an Ac_Value_List + _, acvall, _ = iexpr.children + if (isinstance(acvall, Ac_Value_List)): + assert lbound <= idx and idx <= ubound, f"Array index {idx} is out of bounds in {decl.name}" + return _const_eval_basic_type(acvall.children[idx-lbound], alias_map) + # Fail otherwise return None elif isinstance(expr, Name): spec = search_real_local_alias_spec(expr, alias_map) @@ -3623,10 +3672,14 @@ def consolidate_global_data_into_arg(ast: Program, always_add_global_data_arg: b if not spart: continue for tdecl in children_of_type(spart, Type_Declaration_Stmt): - typ, attr, _ = tdecl.children + typ, attr, decl = tdecl.children if 'PARAMETER' in f"{attr}": - # This is a constant which should have been propagated away already. - continue + # Parameter arrays cannot always be resolved, the indices may not be constant. + # Keep the array declaration but remove the PARAMETER attribute to avoid + # confusing the internal AST builder later. + attr.items = [spec for spec in attr.children if spec.string != 'PARAMETER'] + if not attr.items: + tdecl.items = (typ, None, decl) all_global_vars.append(tdecl.tofortran()) all_derived_types = '\n'.join(all_derived_types) all_global_vars = '\n'.join(all_global_vars) diff --git a/tests/fortran/ast_desugaring_test.py b/tests/fortran/ast_desugaring_test.py index b9dff85461..be373fedc7 100644 --- a/tests/fortran/ast_desugaring_test.py +++ b/tests/fortran/ast_desugaring_test.py @@ -1732,14 +1732,20 @@ def test_constant_expression_replacement(): contains subroutine foo implicit none - real :: res1, res2, res3, unk + real :: res1, res2, res3, res4, res5, unk real, parameter :: & x = -(three + 4.0), & y = -4.0, & z = 3.0 + real, parameter :: arr(3) = [3, 4, 5] + real, parameter :: weird(-1:1) = [3, 4, 5] res1 = unk ** x res2 = unk ** y res3 = unk ** z + res4 = arr(1) + res5 = arr(1) + arr(2) + arr(3) + res4 = weird(0) + res5 = weird(1) + weird(0) - weird(-1) end subroutine foo end module main """).check_with_gfortran().get() @@ -1755,11 +1761,17 @@ def test_constant_expression_replacement(): CONTAINS SUBROUTINE foo IMPLICIT NONE - REAL :: res1, res2, res3, unk + REAL :: res1, res2, res3, res4, res5, unk REAL, PARAMETER :: x = - (7.0), y = - 4.0, z = 3.0 + REAL, PARAMETER :: arr(3) = [3, 4, 5] + REAL, PARAMETER :: weird(- 1 : 1) = [3, 4, 5] res1 = unk ** (- 7.0) res2 = unk ** (- 4.0) res3 = unk ** 3.0 + res4 = 3 + res5 = 12 + res4 = 4 + res5 = 6 END SUBROUTINE foo END MODULE main """.strip()