Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 57 additions & 4 deletions dace/frontend/fortran/ast_desugaring.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
Deallocate_Stmt, Close_Stmt, Goto_Stmt, Continue_Stmt, Format_Stmt, Stmt_Function_Stmt, Internal_Subprogram_Part, \
Private_Components_Stmt, Generic_Spec, Language_Binding_Spec, Type_Attr_Spec, Suffix, Proc_Component_Def_Stmt, \
Proc_Decl, End_Type_Stmt, End_Interface_Stmt, Procedure_Declaration_Stmt, Pointer_Assignment_Stmt, Cycle_Stmt, \
Equiv_Operand, Case_Value_Range_List
Equiv_Operand, Case_Value_Range_List, Ac_Value_List, Explicit_Shape_Spec_List
from fparser.two.Fortran2008 import Procedure_Stmt, Type_Declaration_Stmt, Error_Stop_Stmt
from fparser.two.utils import Base, walk, BinaryOpBase, UnaryOpBase, NumberBase, BlockBase

Expand Down Expand Up @@ -638,6 +638,55 @@ def _eval_real_literal(x: Union[Signed_Real_Literal_Constant, Real_Literal_Const

def _const_eval_basic_type(expr: Base, alias_map: SPEC_TABLE) -> Optional[NUMPY_TYPES]:
if isinstance(expr, (Part_Ref, Data_Ref)):
name_node = expr.children[0]
# TODO: Support multidimensional array access.
if (len(expr.children) > 1 and isinstance(expr.children[1], Section_Subscript_List)):
subsc = expr.children[1]
if (len(subsc.children) == 1):
idx = _const_eval_basic_type(subsc.children[0], alias_map)
if idx is None:
# Array index is not constant
return None
# This is just copied behavior from 'Name'
# But we need to keep track of idx, so we can't just do
# return _const_eval_basic_type(name_node, alias_map)
spec = search_real_local_alias_spec(name_node, alias_map)
if not spec:
# Does not even have a valid identifier.
return None
decl = alias_map[spec]
if not isinstance(decl, Entity_Decl):
# Is not even a data entity.
return None
# Find array declaration bounds
shape = singular(children_of_type(decl, Explicit_Shape_Spec_List))
shape = singular(children_of_type(shape, Explicit_Shape_Spec)).children
assert len(shape) == 2
assert shape[1] is not None
lbound = 1
if shape[0] is not None:
lbound = _const_eval_basic_type(shape[0], alias_map)
ubound = _const_eval_basic_type(shape[1], alias_map)
if lbound is None or ubound is None:
# Shape is not constant
return None
typ = find_type_of_entity(decl, alias_map)
if not typ or not typ.const:
# Does not have a constant type.
return None
init = atmost_one(children_of_type(decl, Initialization))
_, iexpr = init.children
if f"{iexpr}" == 'NULL()':
# We don't have good representation of "null pointer".
return None
# Expect an Array_Constructor
if (isinstance(iexpr, Array_Constructor)):
# Expect an Ac_Value_List
_, acvall, _ = iexpr.children
if (isinstance(acvall, Ac_Value_List)):
assert lbound <= idx and idx <= ubound, f"Array index {idx} is out of bounds in {decl.name}"
return _const_eval_basic_type(acvall.children[idx-lbound], alias_map)
# Fail otherwise
return None
elif isinstance(expr, Name):
spec = search_real_local_alias_spec(expr, alias_map)
Expand Down Expand Up @@ -3623,10 +3672,14 @@ def consolidate_global_data_into_arg(ast: Program, always_add_global_data_arg: b
if not spart:
continue
for tdecl in children_of_type(spart, Type_Declaration_Stmt):
typ, attr, _ = tdecl.children
typ, attr, decl = tdecl.children
if 'PARAMETER' in f"{attr}":
# This is a constant which should have been propagated away already.
continue
# Parameter arrays cannot always be resolved, the indices may not be constant.
# Keep the array declaration but remove the PARAMETER attribute to avoid
# confusing the internal AST builder later.
attr.items = [spec for spec in attr.children if spec.string != 'PARAMETER']
if not attr.items:
tdecl.items = (typ, None, decl)
all_global_vars.append(tdecl.tofortran())
all_derived_types = '\n'.join(all_derived_types)
all_global_vars = '\n'.join(all_global_vars)
Expand Down
16 changes: 14 additions & 2 deletions tests/fortran/ast_desugaring_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1732,14 +1732,20 @@ def test_constant_expression_replacement():
contains
subroutine foo
implicit none
real :: res1, res2, res3, unk
real :: res1, res2, res3, res4, res5, unk
real, parameter :: &
x = -(three + 4.0), &
y = -4.0, &
z = 3.0
real, parameter :: arr(3) = [3, 4, 5]
real, parameter :: weird(-1:1) = [3, 4, 5]
res1 = unk ** x
res2 = unk ** y
res3 = unk ** z
res4 = arr(1)
res5 = arr(1) + arr(2) + arr(3)
res4 = weird(0)
res5 = weird(1) + weird(0) - weird(-1)
end subroutine foo
end module main
""").check_with_gfortran().get()
Expand All @@ -1755,11 +1761,17 @@ def test_constant_expression_replacement():
CONTAINS
SUBROUTINE foo
IMPLICIT NONE
REAL :: res1, res2, res3, unk
REAL :: res1, res2, res3, res4, res5, unk
REAL, PARAMETER :: x = - (7.0), y = - 4.0, z = 3.0
REAL, PARAMETER :: arr(3) = [3, 4, 5]
REAL, PARAMETER :: weird(- 1 : 1) = [3, 4, 5]
res1 = unk ** (- 7.0)
res2 = unk ** (- 4.0)
res3 = unk ** 3.0
res4 = 3
res5 = 12
res4 = 4
res5 = 6
END SUBROUTINE foo
END MODULE main
""".strip()
Expand Down