diff --git a/dace/frontend/fortran/ast_transforms.py b/dace/frontend/fortran/ast_transforms.py index 5cac92969e..a402aa405f 100644 --- a/dace/frontend/fortran/ast_transforms.py +++ b/dace/frontend/fortran/ast_transforms.py @@ -2021,9 +2021,33 @@ def visit_BinOp_Node(self, node: ast_internal_classes.BinOp_Node): for i in range(len(node.rval.value_list)): assigns.append(ast_internal_classes.BinOp_Node( lval=ast_internal_classes.Array_Subscript_Node(name=node.lval, indices=[ - ast_internal_classes.Int_Literal_Node(value=str(i + 1))], type=node.type, parent=node.parent), + ast_internal_classes.Int_Literal_Node(value=str(i))], type=node.type, parent=node.parent), op="=", rval=node.rval.value_list[i], line_number=node.line_number, parent=node.parent, - typ=node.type)) + type=node.type)) + return ast_internal_classes.Execution_Part_Node(execution=assigns) + return self.generic_visit(node) + + +class ReplaceArrayAssignment(NodeTransformer): + def visit_BinOp_Node(self, node: ast_internal_classes.BinOp_Node): + + if isinstance(node.rval, ast_internal_classes.Array_Constructor_Node): + subscript = node.lval + # Find value_list on LHS and point subscript to it. + if isinstance(subscript, ast_internal_classes.Data_Ref_Node): + while True: + if isinstance(subscript, ast_internal_classes.Data_Ref_Node): + subscript = subscript.part_ref + + if isinstance(subscript, ast_internal_classes.Array_Subscript_Node): + break + + assigns = [] + for i in range(len(node.rval.value_list)): + subscript.indices = [ast_internal_classes.Int_Literal_Node(value=str(i))] + new = copy.deepcopy(node) + new.rval = node.rval.value_list[i] + assigns.append(new) return ast_internal_classes.Execution_Part_Node(execution=assigns) return self.generic_visit(node) diff --git a/dace/frontend/fortran/fortran_parser.py b/dace/frontend/fortran/fortran_parser.py index f406bdcbd9..4f5c78a878 100644 --- a/dace/frontend/fortran/fortran_parser.py +++ b/dace/frontend/fortran/fortran_parser.py @@ -2192,6 +2192,13 @@ def binop2sdfg(self, node: ast_internal_classes.BinOp_Node, sdfg: SDFG, cfg: Con :param cfg: The control flow region to which the node should be translated """ + # Transform array assignments into individual assignments + if isinstance(node.rval, ast_internal_classes.Array_Constructor_Node): + # node.lval = node.lval.name + # new_exec = ast_transforms.ReplaceArrayConstructor().visit(node) + new_exec = ast_transforms.ReplaceArrayAssignment().visit(node) + self.translate(new_exec, sdfg, cfg) + return calls = list(mywalk(node, ast_internal_classes.Call_Expr_Node)) if len(calls) == 1: augmented_call = calls[0] diff --git a/tests/fortran/array_test.py b/tests/fortran/array_test.py index f698acf455..647c645e5b 100644 --- a/tests/fortran/array_test.py +++ b/tests/fortran/array_test.py @@ -258,6 +258,26 @@ def test_pass_an_arrayslice_that_looks_like_a_scalar_from_outside_with_symbolic_ assert d[0] == 65 +def test_array_constructor(): + """Unroll constant array constructors, like for parameter arrays.""" + sources, main = SourceCodeBuilder().add_file(""" +subroutine main(var, idx) + double precision :: d(3) = [3.14, 0.58, 1.41] + double precision :: var(2) + integer :: idx + var(1) = d(idx) + var(2) = d(idx+1) +end subroutine main +""").check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'main') + sdfg.simplify(verbose=True) + var = np.full([2], 42, order="F", dtype=np.float64) + idx = np.int32(2) + sdfg(var=var, idx=idx) + assert(var[0] == 0.58) + assert(var[1] == 1.41) + + if __name__ == "__main__": test_fortran_frontend_array_3dmap() test_fortran_frontend_array_access() @@ -266,3 +286,4 @@ def test_pass_an_arrayslice_that_looks_like_a_scalar_from_outside_with_symbolic_ test_fortran_frontend_array_multiple_ranges_with_symbols() test_fortran_frontend_twoconnector() test_fortran_frontend_memlet_in_map_test() + test_array_constructor()