diff --git a/dace/frontend/fortran/fortran_parser.py b/dace/frontend/fortran/fortran_parser.py index 836070c1f4..5f42c7502f 100644 --- a/dace/frontend/fortran/fortran_parser.py +++ b/dace/frontend/fortran/fortran_parser.py @@ -779,7 +779,7 @@ def ifstmt2sdfg(self, node: ast_internal_classes.If_Stmt_Node, sdfg: SDFG, cfg: # If there's nothing inside the branch, add a noop state to get a valid SDFG and let simplify take # care of the rest. case_body.add_state('noop', is_start_block=True) - else: + else: name = f"Conditional_l_{str(node.line_number[0])}_c_{str(node.line_number[1])}" prev_block = None if cfg not in self.last_sdfg_states else self.last_sdfg_states[cfg] @@ -1071,16 +1071,16 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, # FIXME: Dirty hack to let translator create clean SDFG state names if node.line_number == -1: node.line_number = (0, 0) - if isinstance(litval, ast_internal_classes.Int_Literal_Node): + if isinstance(litval, ast_internal_classes.Int_Literal_Node): sym_dict[local_name.name] = litval.value new_sdfg.add_symbol(local_name.name, dtypes.int32) - else: + else: assigns.append( ast_internal_classes.BinOp_Node(lval=ast_internal_classes.Name_Node(name=local_name.name), rval=litval, op="=", line_number=node.line_number)) - + # This handles the case where the function is called with symbols for parameter, symbol in symbol_arguments: sym_dict[parameter.name] = symbol.name @@ -1107,7 +1107,7 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, needs_replacement = {} for variable_in_call in variables_in_call: local_name = parameters[variables_in_call.index(variable_in_call)] - + local_definition = namefinder.specs.get(local_name.name) if local_definition is None: raise ValueError("Variable " + local_name.name + " is not defined in the function") @@ -1134,7 +1134,7 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, for i in sdfg.symbols: sym_dict[i] = i - sym_dict.update(self.temporary_sym_dict[new_sdfg.name]) + sym_dict.update(self.temporary_sym_dict[new_sdfg.name]) not_found_write_names = [] not_found_read_names = [] @@ -1158,7 +1158,7 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, addedmemlets = [] globalmemlets = [] - + # This handles the case where the function is called with read variables found in a module cached_names = [a[0] for a in self.module_vars] for i in not_found_read_names: @@ -1504,7 +1504,7 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, if is_start: first_substate = new_sdfg.add_state("start_state", is_start_block=True) self.last_sdfg_states[new_sdfg] = first_substate - + substate = new_sdfg.add_state("dummy_state_for_symbol_init") entries={} for i in symbol_assigns: @@ -1641,7 +1641,7 @@ def add_full_object(self, new_sdfg: SDFG,sdfg:SDFG, array: dat.Array, local_name pass else: #raise ValueError("Shape of array does not match") - local_shape,local_strides = self.fix_shapes_before_adding_from_nested(sdfg,new_sdfg,local_shape,local_strides) + local_shape,local_strides = self.fix_shapes_before_adding_from_nested(sdfg,new_sdfg,local_shape,local_strides) reshape_viewname, reshape_view = sdfg.add_view(sdfg_name + "_view_reshape_" + str(self.views), local_shape, dtype, @@ -1659,15 +1659,15 @@ def add_full_object(self, new_sdfg: SDFG,sdfg:SDFG, array: dat.Array, local_name rv = substate.add_read(reshape_viewname) w = substate.add_write(sdfg_name) substate.add_edge(rv, 'views', w, None, dpcp(memlet)) - local_shape,local_strides = self.fix_shapes_before_adding_nested(sdfg,new_sdfg,local_shape,local_strides) + local_shape,local_strides = self.fix_shapes_before_adding_nested(sdfg,new_sdfg,local_shape,local_strides) new_sdfg.add_array(self.name_mapping[new_sdfg][local_name.name], local_shape, dtype, array.storage, strides=local_strides, - offset=local_offsets) + offset=local_offsets) - return True, (wv, rv) + return True, (wv, rv) if new_sdfg.arrays.get(self.name_mapping[new_sdfg][local_name.name]) is None: if shape == []: new_sdfg.add_scalar(self.name_mapping[new_sdfg][local_name.name], array.dtype, @@ -1681,7 +1681,7 @@ def add_full_object(self, new_sdfg: SDFG,sdfg:SDFG, array: dat.Array, local_name array.storage, strides=strides, offset=offset) - return False, None + return False, None else: #raise warning that array already exists in sdfg print(f"Array {self.name_mapping[new_sdfg][local_name.name]} already exists in SDFG {new_sdfg.name}") @@ -1698,7 +1698,7 @@ def add_simple_array_to_element_view_pair_in_tower(self, sdfg: SDFG, array: dat. shape=[1] offsets_zero=[0] strides=[1] - shape,strides = self.fix_shapes_before_adding(sdfg,shape,strides) + shape,strides = self.fix_shapes_before_adding(sdfg,shape,strides) viewname, view = sdfg.add_view(view_name, shape, array.dtype, @@ -1778,11 +1778,11 @@ def get_local_shape(self, sdfg:SDFG,local_definition:ast_internal_classes.Var_De offsets.append(offset_value) if len(sizes)==0: return [1],[0],[0],[1] - strides = [dat._prod(sizes[:i]) for i in range(len(sizes))] + strides = [dat._prod(sizes[:i]) for i in range(len(sizes))] return sizes, offsets, actual_offsets,strides else: return [1],[0],[0],[1] - + def process_variable_call(self, variable_in_calling_context: ast_internal_classes.FNode, local_name:ast_internal_classes.FNode, sdfg: SDFG, new_sdfg: SDFG, substate:SDFGState, read:bool,write:bool,local_definition:ast_internal_classes.Var_Decl_Node): # We need to first check and have separate handling for: @@ -1835,7 +1835,7 @@ def process_variable_call(self, variable_in_calling_context: ast_internal_classe shape=[1] offsets_zero=[0] strides=[1] - shape,strides = self.fix_shapes_before_adding(sdfg,shape,strides) + shape,strides = self.fix_shapes_before_adding(sdfg,shape,strides) viewname, view = sdfg.add_view(sdfg_name + "_view_" + str(self.views), shape, array.dtype, @@ -1857,7 +1857,7 @@ def process_variable_call(self, variable_in_calling_context: ast_internal_classe self.views = self.views + 1 is_scalar=(len(shape)==0) or (len(shape)==1 and shape[0]==1) is_local_scalar=(len(local_shape)==0) or (len(local_shape)==1 and local_shape[0]==1) - + if local_shape!=shape and (not(is_scalar and is_local_scalar)): #we must add an extra view reshaping the access to the local shape if len(shape)==len(local_shape): @@ -1872,10 +1872,10 @@ def process_variable_call(self, variable_in_calling_context: ast_internal_classe local_offsets[i]=offsets[i] recompute_strides=True if recompute_strides: - local_strides = [dat._prod(local_shape[:i]) for i in range(len(local_shape))] + local_strides = [dat._prod(local_shape[:i]) for i in range(len(local_shape))] + - - else: + else: if len(local_shape)!=1: raise NotImplementedError("Local shape not 1") local_shape,local_strides = self.fix_shapes_before_adding_from_nested(sdfg,new_sdfg,local_shape,local_strides) @@ -1885,8 +1885,8 @@ def process_variable_call(self, variable_in_calling_context: ast_internal_classe storage=array.storage, strides=local_strides, offset=local_offsets) - - + + memlet=Memlet.from_array(viewname, sdfg.arrays[viewname]) if write: res_v_read = substate.add_read(reshape_viewname) @@ -1896,7 +1896,7 @@ def process_variable_call(self, variable_in_calling_context: ast_internal_classe res_v_write = substate.add_write(reshape_viewname) substate.add_edge(wv, None, res_v_write, None, dpcp(memlet)) wv=res_v_write - + local_shape,local_strides = self.fix_shapes_before_adding_nested(sdfg,new_sdfg,local_shape,local_strides) new_sdfg.add_array(self.name_mapping[new_sdfg][local_name.name], @@ -1949,7 +1949,7 @@ def process_variable_call(self, variable_in_calling_context: ast_internal_classe if views_needed: return True, [sdfg_name, views[0], views[1], variable_in_calling_context] else: - + return True, [sdfg_name,last_read, last_written, variable_in_calling_context] elif isinstance(member,ast_internal_classes.Array_Subscript_Node): @@ -1971,7 +1971,7 @@ def process_variable_call(self, variable_in_calling_context: ast_internal_classe if views_needed: return True, [sdfg_name, views[0], views[1], variable_in_calling_context] else: - + return True, [sdfg_name,last_read, last_written, variable_in_calling_context] else: @@ -1981,14 +1981,14 @@ def process_variable_call(self, variable_in_calling_context: ast_internal_classe #this is a simple array, but must still have first view to Array and then to subset. last_read, last_written=self.add_basic_view_pair_in_tower(sdfg,array,name_chain,member,substate,last_read,last_written,read,write) last_read, last_written=self.add_simple_array_to_element_view_pair_in_tower(sdfg,array,name_chain,member,substate,last_read,last_written,read,write,shape,offsets,strides,subset) - + if len(shape)==0: shape=[1] offsets=[0] strides=[1] is_scalar=(len(shape)==0) or (len(shape)==1 and shape[0]==1) is_local_scalar=(len(local_shape)==0) or (len(local_shape)==1 and local_shape[0]==1) - if local_shape!=shape and (not(is_scalar and is_local_scalar)): + if local_shape!=shape and (not(is_scalar and is_local_scalar)): if len(shape)==len(local_shape): print("Shapes are not equal, but the same size. We hope that the symbolic sizes evaluate to the same values") #this is not necessary, as here we use the outside sizes for some reason??? @@ -2000,10 +2000,10 @@ def process_variable_call(self, variable_in_calling_context: ast_internal_classe # local_offsets[i]=offsets[i] # recompute_strides=True # if recompute_strides: - # local_strides = [dat._prod(local_shape[:i]) for i in range(len(local_shape))] - else: - raise NotImplementedError("Local shape not the same as outside shape") - shape,strides = self.fix_shapes_before_adding_nested(sdfg,new_sdfg,shape,strides) + # local_strides = [dat._prod(local_shape[:i]) for i in range(len(local_shape))] + else: + raise NotImplementedError("Local shape not the same as outside shape") + shape,strides = self.fix_shapes_before_adding_nested(sdfg,new_sdfg,shape,strides) new_sdfg.add_array(self.name_mapping[new_sdfg][local_name.name], shape, array.dtype, @@ -2455,7 +2455,7 @@ def vardecl2sdfg(self, node: ast_internal_classes.Var_Decl_Node, sdfg: SDFG, cfg # sdfg.add_datadesc(self.name_mapping[sdfg][node.name], arr_dtype) else: - sizes,strides = self.fix_shapes_before_adding(sdfg,sizes,strides) + sizes,strides = self.fix_shapes_before_adding(sdfg,sizes,strides) # print("Adding local array",self.name_mapping[sdfg][node.name],sizes,datatype,offset,strides,transient) sdfg.add_array(self.name_mapping[sdfg][node.name], @@ -2491,7 +2491,7 @@ def fix_shapes_before_adding_nested(self, sdfg: SDFG,new_sdfg,sizes:List,strides if not hasattr(i,"free_symbols"): continue free_symbols=i.free_symbols - + for s in free_symbols: if new_sdfg.symbols.get(s.name) is not None: #self.temporary_sym_dict[new_sdfg.name]["sym_"+s.name]=["sym_"+s.name] @@ -2511,10 +2511,10 @@ def fix_shapes_before_adding_nested(self, sdfg: SDFG,new_sdfg,sizes:List,strides self.temporary_sym_dict[new_sdfg.name]["sym_"+s.name]=name_in_parent found=True if not found: - raise ValueError(f"Temporary symbol not found for {s.name}") + raise ValueError(f"Temporary symbol not found for {s.name}") else: - - + + if sdfg.symbols.get(s.name) is not None: self.temporary_sym_dict[new_sdfg.name][s.name]=s.name new_sdfg.add_symbol(s.name, sdfg.symbols[s.name].dtype) @@ -2523,25 +2523,25 @@ def fix_shapes_before_adding_nested(self, sdfg: SDFG,new_sdfg,sizes:List,strides self.temporary_sym_dict[new_sdfg.name]["sym_"+s.name]=s.name i=i.subs(s,sym.symbol("sym_"+s.name)) else: - print(f"Symbol {s.name} not found in arrays") + print(f"Symbol {s.name} not found in arrays") raise ValueError(f"Symbol {s.name} not found in arrays") - + sizes= list(sizes) sizes[idx]=i sizes=tuple(sizes) #this means it is an input, so we can try adding it to the symbols mapping if changed: - strides = [dat._prod(sizes[:i]) for i in range(len(sizes))] - return sizes,strides - + strides = [dat._prod(sizes[:i]) for i in range(len(sizes))] + return sizes,strides + def fix_shapes_before_adding_from_nested(self, sdfg: SDFG,new_sdfg,sizes:List,strides:List): changed=False for idx, i in enumerate(sizes): if not hasattr(i,"free_symbols"): continue free_symbols=i.free_symbols - + for s in free_symbols: if sdfg.symbols.get(s.name) is not None: #self.temporary_sym_dict[new_sdfg.name]["sym_"+s.name]=["sym_"+s.name] @@ -2563,16 +2563,16 @@ def fix_shapes_before_adding_from_nested(self, sdfg: SDFG,new_sdfg,sizes:List,st print("here") else: - print(f"Symbol {s.name} not found in arrays") - + print(f"Symbol {s.name} not found in arrays") + sizes= list(sizes) sizes[idx]=i sizes=tuple(sizes) #this means it is an input, so we can try adding it to the symbols mapping if changed: - strides = [dat._prod(sizes[:i]) for i in range(len(sizes))] - return sizes,strides + strides = [dat._prod(sizes[:i]) for i in range(len(sizes))] + return sizes,strides def fix_shapes_before_adding(self, sdfg: SDFG,sizes:List,strides:List): changed=False @@ -2580,7 +2580,7 @@ def fix_shapes_before_adding(self, sdfg: SDFG,sizes:List,strides:List): if not hasattr(i,"free_symbols"): continue free_symbols=i.free_symbols - + for s in free_symbols: if sdfg.symbols.get(s.name) is not None: pass @@ -2601,18 +2601,18 @@ def fix_shapes_before_adding(self, sdfg: SDFG,sizes:List,strides:List): self.temporary_sym_dict[sdfg.name]["sym_"+s.name]=name_in_parent found=True if not found: - raise ValueError(f"Temporary symbol not found for {s.name}") + raise ValueError(f"Temporary symbol not found for {s.name}") else: - print(f"Symbol {s.name} not found in arrays") - + print(f"Symbol {s.name} not found in arrays") + sizes= list(sizes) sizes[idx]=i sizes=tuple(sizes) #this means it is an input, so we can try adding it to the symbols mapping if changed: - strides = [dat._prod(sizes[:i]) for i in range(len(sizes))] - return sizes,strides + strides = [dat._prod(sizes[:i]) for i in range(len(sizes))] + return sizes,strides def break2sdfg(self, node: ast_internal_classes.Break_Node, sdfg: SDFG, cfg: ControlFlowRegion): break_block = BreakBlock(f'Break_l_{str(node.line_number[0])}_c_{str(node.line_number[1])}') @@ -3046,7 +3046,7 @@ def create_sdfg_from_string( own_ast, program = create_internal_ast(cfg) cfg = SDFGConfig( - {sdfg_name: f"{sdfg_name}_function"}, + {sdfg_name: f"{sdfg_name}_function"}, config_injections=None, normalize_offsets=normalize_offsets, multiple_sdfgs=False @@ -3187,6 +3187,12 @@ def create_sdfg_from_fortran_file_with_options( else: ast = correct_for_function_calls(ast) + import os + FORTRAN_AST_PATH = os.getenv("FORTRAN_AST_PATH") + if FORTRAN_AST_PATH is not None: + with open(FORTRAN_AST_PATH, 'w') as f: + f.write(ast.tofortran()) + dep_graph = compute_dep_graph(ast, 'radiation_interface') parse_order = list(reversed(list(nx.topological_sort(dep_graph)))) @@ -3254,7 +3260,7 @@ def create_sdfg_from_fortran_file_with_options( struct_deps_finder = ast_transforms.StructDependencyLister(structs_lister.names) struct_deps_finder.visit(i) struct_deps = struct_deps_finder.structs_used - + for j, pointing, point_name in zip(struct_deps, struct_deps_finder.is_pointer, struct_deps_finder.pointer_names): if j not in struct_dep_graph.nodes: @@ -3265,7 +3271,7 @@ def create_sdfg_from_fortran_file_with_options( program.structures = ast_transforms.Structures(structs_lister.structs) program = run_ast_transformations(partial_ast, program, cfg, True) - + # functions_and_subroutines_builder = ast_transforms.FindFunctionAndSubroutines() # functions_and_subroutines_builder.visit(program) @@ -3345,7 +3351,7 @@ def create_sdfg_from_fortran_file_with_options( # program = ast_transforms.ArgumentExtractor(program).visit(program) # program = ast_transforms.ElementalFunctionExpander( # functions_and_subroutines_builder.names, ast=program).visit(program) - + # program = ast_transforms.ForDeclarer().visit(program) # program = ast_transforms.PointerRemoval().visit(program) # program = ast_transforms.IndexExtractor(program, normalize_offsets).visit(program)