From 67e05806f9a28e1b3fd485e83ef049e27b584575 Mon Sep 17 00:00:00 2001 From: Gagandeep Singh Date: Thu, 9 Nov 2023 23:42:45 +0530 Subject: [PATCH] wip --- integration_tests/symbolics_01.py | 6 + src/bin/lpython.cpp | 73 +- src/libasr/ASR.asdl | 4 +- src/libasr/CMakeLists.txt | 2 + src/libasr/asdl_cpp.py | 5 +- src/libasr/asr_utils.cpp | 207 +- src/libasr/asr_utils.h | 150 +- src/libasr/asr_verify.cpp | 33 +- src/libasr/codegen/KaleidoscopeJIT.h | 8 +- src/libasr/codegen/asr_to_c.cpp | 65 +- src/libasr/codegen/asr_to_c_cpp.h | 54 +- src/libasr/codegen/asr_to_cpp.cpp | 6 +- src/libasr/codegen/asr_to_fortran.cpp | 1875 +++++++++++++++++ src/libasr/codegen/asr_to_fortran.h | 15 + src/libasr/codegen/asr_to_llvm.cpp | 581 +++-- src/libasr/codegen/asr_to_wasm.cpp | 19 +- src/libasr/codegen/evaluator.cpp | 2 +- src/libasr/codegen/llvm_array_utils.cpp | 45 +- src/libasr/codegen/llvm_array_utils.h | 14 +- src/libasr/codegen/llvm_utils.cpp | 32 +- src/libasr/compiler_tester/tester.py | 6 + src/libasr/dwarf_convert.py | 18 +- src/libasr/gen_pass.py | 1 + src/libasr/pass/array_op.cpp | 391 +++- src/libasr/pass/implied_do_loops.cpp | 204 +- src/libasr/pass/inline_function_calls.cpp | 273 ++- src/libasr/pass/insert_deallocate.cpp | 67 + src/libasr/pass/insert_deallocate.h | 14 + src/libasr/pass/instantiate_template.cpp | 237 ++- src/libasr/pass/instantiate_template.h | 21 +- .../pass/intrinsic_array_function_registry.h | 23 +- src/libasr/pass/intrinsic_function.cpp | 10 +- src/libasr/pass/intrinsic_function_registry.h | 405 +++- src/libasr/pass/nested_vars.cpp | 25 + src/libasr/pass/pass_array_by_data.cpp | 92 +- src/libasr/pass/pass_manager.h | 106 +- src/libasr/pass/pass_utils.cpp | 42 +- src/libasr/pass/pass_utils.h | 41 +- src/libasr/pass/replace_symbolic.h | 2 +- src/libasr/pass/subroutine_from_function.cpp | 56 +- .../transform_optional_argument_functions.cpp | 2 +- src/libasr/pass/where.cpp | 8 + src/libasr/pickle.cpp | 8 +- src/libasr/pickle.h | 4 +- src/libasr/runtime/lfortran_intrinsics.c | 443 ++-- src/libasr/runtime/lfortran_intrinsics.h | 3 + src/libasr/string_utils.cpp | 16 + src/libasr/string_utils.h | 1 + src/libasr/utils.h | 79 +- src/libasr/utils2.cpp | 180 ++ src/lpython/python_evaluator.cpp | 2 +- src/lpython/semantics/python_ast_to_asr.cpp | 8 +- 52 files changed, 5180 insertions(+), 804 deletions(-) create mode 100644 src/libasr/codegen/asr_to_fortran.cpp create mode 100644 src/libasr/codegen/asr_to_fortran.h create mode 100644 src/libasr/pass/insert_deallocate.cpp create mode 100644 src/libasr/pass/insert_deallocate.h diff --git a/integration_tests/symbolics_01.py b/integration_tests/symbolics_01.py index 3187a9005f7..ae0dab1dd42 100644 --- a/integration_tests/symbolics_01.py +++ b/integration_tests/symbolics_01.py @@ -13,4 +13,10 @@ def main0(): assert(z == pi + y) assert(z != S(2)*pi + y) + # testing PR 2404 + p: S = Symbol('pi') + print(p) + print(p != pi) + assert(p != pi) + main0() \ No newline at end of file diff --git a/src/bin/lpython.cpp b/src/bin/lpython.cpp index bdc2b92aafa..36eb8c43e44 100644 --- a/src/bin/lpython.cpp +++ b/src/bin/lpython.cpp @@ -80,29 +80,6 @@ std::string get_kokkos_dir() throw LCompilers::LCompilersException("LFORTRAN_KOKKOS_DIR is not defined"); } -int visualize_json(std::string &astr_data_json, LCompilers::Platform os) { - using namespace LCompilers; - std::string file_loc = LCompilers::LPython::generate_visualize_html(astr_data_json); - std::string open_cmd = ""; - switch (os) { - case Linux: open_cmd = "xdg-open"; break; - case Windows: open_cmd = "start"; break; - case macOS_Intel: - case macOS_ARM: open_cmd = "open"; break; - default: - std::cerr << "Unsupported Platform " << pf2s(os) <(*this); }", 1) self.emit("public:") self.emit( "std::string s, indtd = \"\";", 1) + self.emit( "bool no_loc = false;", 1) self.emit( "int indent_level = 0, indent_spaces = 4;", 1) # Storing a reference to LocationManager like this isn't ideal. # One must make sure JsonBaseVisitor isn't reused in a case where AST/ASR has changed @@ -1739,7 +1740,9 @@ def visitModule(self, mod): self.emit( "indtd = std::string(indent_level*indent_spaces, ' ');",2) self.emit( "}",1) self.emit( "void append_location(std::string &s, uint32_t first, uint32_t last) {", 1) - self.emit( 's.append("\\"loc\\": {");', 2); + self.emit( 'if (no_loc) return;', 2) + self.emit( 's.append(",\\n" + indtd);', 2) + self.emit( 's.append("\\"loc\\": {");', 2) self.emit( 'inc_indent();', 2) self.emit( 's.append("\\n" + indtd);', 2) self.emit( 's.append("\\"first\\": " + std::to_string(first));', 2) diff --git a/src/libasr/asr_utils.cpp b/src/libasr/asr_utils.cpp index cb02443863a..3e393875a4a 100644 --- a/src/libasr/asr_utils.cpp +++ b/src/libasr/asr_utils.cpp @@ -9,6 +9,7 @@ #include #include #include +#include namespace LCompilers { @@ -133,7 +134,8 @@ void extract_module_python(const ASR::TranslationUnit_t &m, } } -void update_call_args(Allocator &al, SymbolTable *current_scope, bool implicit_interface) { +void update_call_args(Allocator &al, SymbolTable *current_scope, bool implicit_interface, + std::map changed_external_function_symbol) { /* Iterate over body of program, check if there are any subroutine calls if yes, iterate over its args and update the args if they are equal to the old symbol @@ -146,11 +148,34 @@ void update_call_args(Allocator &al, SymbolTable *current_scope, bool implicit_i This function updates `sub2` to use the new symbol `c` that is now a function, not a variable. Along with this, it also updates the args of `sub2` to use the new symbol `c` instead of the old one. */ - class UpdateArgsVisitor : public PassUtils::PassVisitor + + class ArgsReplacer : public ASR::BaseExprReplacer { + public: + Allocator &al; + ASR::symbol_t* new_sym; + + ArgsReplacer(Allocator &al_) : al(al_) {} + + void replace_Var(ASR::Var_t* x) { + *current_expr = ASRUtils::EXPR(ASR::make_Var_t(al, x->base.base.loc, new_sym)); + } + }; + + class ArgsVisitor : public ASR::CallReplacerOnExpressionsVisitor { public: + Allocator &al; SymbolTable* scope = current_scope; - UpdateArgsVisitor(Allocator &al) : PassVisitor(al, nullptr) {} + ArgsReplacer replacer; + std::map &changed_external_function_symbol; + ArgsVisitor(Allocator &al_, std::map &changed_external_function_symbol_) : al(al_), replacer(al_), + changed_external_function_symbol(changed_external_function_symbol_) {} + + void call_replacer_(ASR::symbol_t* new_sym) { + replacer.current_expr = current_expr; + replacer.new_sym = new_sym; + replacer.replace_expr(*current_expr); + } ASR::symbol_t* fetch_sym(ASR::symbol_t* arg_sym_underlying) { ASR::symbol_t* sym = nullptr; @@ -165,26 +190,59 @@ void update_call_args(Allocator &al, SymbolTable *current_scope, bool implicit_i } return sym; } + + void handle_Var(ASR::expr_t* arg_expr, ASR::expr_t** expr_to_replace) { + if (ASR::is_a(*arg_expr)) { + ASR::Var_t* arg_var = ASR::down_cast(arg_expr); + ASR::symbol_t* arg_sym = arg_var->m_v; + ASR::symbol_t* arg_sym_underlying = ASRUtils::symbol_get_past_external(arg_sym); + ASR::symbol_t* sym = fetch_sym(arg_sym_underlying); + if (sym != arg_sym) { + ASR::expr_t** current_expr_copy = current_expr; + current_expr = const_cast((expr_to_replace)); + this->call_replacer_(sym); + current_expr = current_expr_copy; + } + } + } + void visit_SubroutineCall(const ASR::SubroutineCall_t& x) { ASR::SubroutineCall_t* subrout_call = (ASR::SubroutineCall_t*)(&x); for (size_t j = 0; j < subrout_call->n_args; j++) { ASR::call_arg_t arg = subrout_call->m_args[j]; ASR::expr_t* arg_expr = arg.m_value; - if (ASR::is_a(*arg_expr)) { - ASR::Var_t* arg_var = ASR::down_cast(arg_expr); - ASR::symbol_t* arg_sym = arg_var->m_v; - ASR::symbol_t* arg_sym_underlying = ASRUtils::symbol_get_past_external(arg_sym); - ASR::symbol_t* sym = fetch_sym(arg_sym_underlying); - if (sym != arg_sym) { - subrout_call->m_args[j].m_value = ASRUtils::EXPR(ASR::make_Var_t(al, arg_expr->base.loc, sym)); - } - } + handle_Var(arg_expr, &(subrout_call->m_args[j].m_value)); + } + } + + void visit_FunctionCall(const ASR::FunctionCall_t& x) { + ASR::FunctionCall_t* func_call = (ASR::FunctionCall_t*)(&x); + for (size_t j = 0; j < func_call->n_args; j++) { + ASR::call_arg_t arg = func_call->m_args[j]; + ASR::expr_t* arg_expr = arg.m_value; + handle_Var(arg_expr, &(func_call->m_args[j].m_value)); } } void visit_Function(const ASR::Function_t& x) { ASR::Function_t* func = (ASR::Function_t*)(&x); + scope = func->m_symtab; + ASRUtils::SymbolDuplicator symbol_duplicator(al); + std::map scope_ = scope->get_scope(); + std::vector symbols_to_duplicate; + for (auto it: scope_) { + if (changed_external_function_symbol.find(it.first) != changed_external_function_symbol.end() && + is_external_sym_changed(it.second, changed_external_function_symbol[it.first])) { + symbols_to_duplicate.push_back(it.first); + } + } + + for (auto it: symbols_to_duplicate) { + scope->erase_symbol(it); + symbol_duplicator.duplicate_symbol(changed_external_function_symbol[it], scope); + } + for (size_t i = 0; i < func->n_args; i++) { ASR::expr_t* arg_expr = func->m_args[i]; if (ASR::is_a(*arg_expr)) { @@ -193,7 +251,10 @@ void update_call_args(Allocator &al, SymbolTable *current_scope, bool implicit_i ASR::symbol_t* arg_sym_underlying = ASRUtils::symbol_get_past_external(arg_sym); ASR::symbol_t* sym = fetch_sym(arg_sym_underlying); if (sym != arg_sym) { - func->m_args[i] = ASRUtils::EXPR(ASR::make_Var_t(al, arg_expr->base.loc, sym)); + ASR::expr_t** current_expr_copy = current_expr; + current_expr = const_cast(&(func->m_args[i])); + this->call_replacer_(sym); + current_expr = current_expr_copy; } } } @@ -210,7 +271,7 @@ void update_call_args(Allocator &al, SymbolTable *current_scope, bool implicit_i }; if (implicit_interface) { - UpdateArgsVisitor v(al); + ArgsVisitor v(al, changed_external_function_symbol); SymbolTable *tu_symtab = ASRUtils::get_tu_symtab(current_scope); ASR::asr_t* asr_ = tu_symtab->asr_owner; ASR::TranslationUnit_t* tu = ASR::down_cast2(asr_); @@ -618,7 +679,7 @@ bool use_overloaded(ASR::expr_t* left, ASR::expr_t* right, } if (ASRUtils::symbol_parent_symtab(a_name)->get_counter() != curr_scope->get_counter()) { ADD_ASR_DEPENDENCIES_WITH_NAME(curr_scope, a_name, current_function_dependencies, s2c(al, matched_func_name)); - } + } ASRUtils::insert_module_dependency(a_name, al, current_module_dependencies); ASRUtils::set_absent_optional_arguments_to_null(a_args, func, al); asr = ASRUtils::make_FunctionCall_t_util(al, loc, a_name, sym, @@ -703,7 +764,7 @@ void process_overloaded_unary_minus_function(ASR::symbol_t* proc, ASR::expr_t* o } if (ASRUtils::symbol_parent_symtab(a_name)->get_counter() != curr_scope->get_counter()) { ADD_ASR_DEPENDENCIES_WITH_NAME(curr_scope, a_name, current_function_dependencies, s2c(al, matched_func_name)); - } + } ASRUtils::insert_module_dependency(a_name, al, current_module_dependencies); ASRUtils::set_absent_optional_arguments_to_null(a_args, func, al); asr = ASRUtils::make_FunctionCall_t_util(al, loc, a_name, proc, @@ -1355,6 +1416,7 @@ ASR::symbol_t* import_class_procedure(Allocator &al, const Location& loc, ASR::asr_t* make_Binop_util(Allocator &al, const Location& loc, ASR::binopType binop, ASR::expr_t* lexpr, ASR::expr_t* rexpr, ASR::ttype_t* ttype) { + ASRUtils::make_ArrayBroadcast_t_util(al, loc, lexpr, rexpr); switch (ttype->type) { case ASR::ttypeType::Real: { return ASR::make_RealBinOp_t(al, loc, lexpr, binop, rexpr, @@ -1394,6 +1456,119 @@ ASR::asr_t* make_Cmpop_util(Allocator &al, const Location& loc, ASR::cmpopType c } } +void make_ArrayBroadcast_t_util(Allocator& al, const Location& loc, + ASR::expr_t*& expr1, ASR::expr_t*& expr2, ASR::dimension_t* expr1_mdims, + size_t expr1_ndims) { + ASR::ttype_t* expr1_type = ASRUtils::expr_type(expr1); + Vec shape_args; + shape_args.reserve(al, 1); + shape_args.push_back(al, expr1); + + Vec dims; + dims.reserve(al, 1); + ASR::dimension_t dim; + dim.loc = loc; + dim.m_length = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, + expr1_ndims, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))); + dim.m_start = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, + 1, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))); + dims.push_back(al, dim); + ASR::ttype_t* dest_shape_type = ASRUtils::TYPE(ASR::make_Array_t(al, loc, + ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), dims.p, dims.size(), + ASR::array_physical_typeType::FixedSizeArray)); + + ASR::expr_t* dest_shape = nullptr; + ASR::expr_t* value = nullptr; + ASR::ttype_t* ret_type = nullptr; + if( ASRUtils::is_fixed_size_array(expr1_mdims, expr1_ndims) ) { + Vec lengths; lengths.reserve(al, expr1_ndims); + for( size_t i = 0; i < expr1_ndims; i++ ) { + lengths.push_back(al, ASRUtils::expr_value(expr1_mdims[i].m_length)); + } + dest_shape = ASRUtils::EXPR(ASR::make_ArrayConstant_t(al, loc, + lengths.p, lengths.size(), dest_shape_type, ASR::arraystorageType::ColMajor)); + Vec dims; + dims.reserve(al, 1); + ASR::dimension_t dim; + dim.loc = loc; + dim.m_length = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, + ASRUtils::get_fixed_size_of_array(expr1_mdims, expr1_ndims), + ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))); + dim.m_start = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, + 1, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))); + dims.push_back(al, dim); + + if( ASRUtils::is_value_constant(expr2) && + ASRUtils::get_fixed_size_of_array(expr1_mdims, expr1_ndims) <= 256 ) { + ASR::ttype_t* value_type = ASRUtils::TYPE(ASR::make_Array_t(al, loc, + ASRUtils::type_get_past_array(ASRUtils::expr_type(expr2)), dims.p, dims.size(), + ASR::array_physical_typeType::FixedSizeArray)); + Vec values; + values.reserve(al, ASRUtils::get_fixed_size_of_array(expr1_mdims, expr1_ndims)); + for( int64_t i = 0; i < ASRUtils::get_fixed_size_of_array(expr1_mdims, expr1_ndims); i++ ) { + values.push_back(al, expr2); + } + value = ASRUtils::EXPR(ASR::make_ArrayConstant_t(al, loc, + values.p, values.size(), value_type, ASR::arraystorageType::ColMajor)); + ret_type = value_type; + } + } else { + dest_shape = ASRUtils::EXPR(ASR::make_IntrinsicArrayFunction_t(al, loc, + static_cast(ASRUtils::IntrinsicArrayFunctions::Shape), shape_args.p, + shape_args.size(), 0, dest_shape_type, nullptr)); + } + + if (ret_type == nullptr) { + // TODO: Construct appropriate return type here + // For now simply coping the type from expr1 + ret_type = expr1_type; + } + expr2 = ASRUtils::EXPR(ASR::make_ArrayBroadcast_t(al, loc, expr2, dest_shape, ret_type, value)); + + if (ASRUtils::extract_physical_type(expr1_type) != ASRUtils::extract_physical_type(ret_type)) { + expr2 = ASRUtils::EXPR(ASRUtils::make_ArrayPhysicalCast_t_util(al, loc, expr2, + ASRUtils::extract_physical_type(ret_type), + ASRUtils::extract_physical_type(expr1_type), expr1_type, nullptr)); + } +} + +void make_ArrayBroadcast_t_util(Allocator& al, const Location& loc, + ASR::expr_t*& expr1, ASR::expr_t*& expr2) { + ASR::ttype_t* expr1_type = ASRUtils::expr_type(expr1); + ASR::ttype_t* expr2_type = ASRUtils::expr_type(expr2); + ASR::dimension_t *expr1_mdims = nullptr, *expr2_mdims = nullptr; + size_t expr1_ndims = ASRUtils::extract_dimensions_from_ttype(expr1_type, expr1_mdims); + size_t expr2_ndims = ASRUtils::extract_dimensions_from_ttype(expr2_type, expr2_mdims); + if( expr1_ndims == expr2_ndims ) { + // TODO: Always broadcast both the expressions + return ; + } + + if( expr1_ndims > expr2_ndims ) { + if( ASR::is_a(*expr2) ) { + return ; + } + make_ArrayBroadcast_t_util(al, loc, expr1, expr2, expr1_mdims, expr1_ndims); + } else { + if( ASR::is_a(*expr1) ) { + return ; + } + make_ArrayBroadcast_t_util(al, loc, expr2, expr1, expr2_mdims, expr2_ndims); + } +} + +int64_t compute_trailing_zeros(int64_t number) { + int64_t trailing_zeros = 0; + if (number == 0) { + return 32; + } + while (number % 2 == 0) { + number = number / 2; + trailing_zeros++; + } + return trailing_zeros; +} + //Initialize pointer to zero so that it can be initialized in first call to get_instance ASRUtils::LabelGenerator* ASRUtils::LabelGenerator::label_generator = nullptr; diff --git a/src/libasr/asr_utils.h b/src/libasr/asr_utils.h index 06e42e7fd9d..3a6092ecdcc 100644 --- a/src/libasr/asr_utils.h +++ b/src/libasr/asr_utils.h @@ -20,9 +20,15 @@ } \ SymbolTable* temp_scope = current_scope; \ if (asr_owner_sym && temp_scope->get_counter() != ASRUtils::symbol_parent_symtab(final_sym)->get_counter() && \ - !ASR::is_a(*asr_owner_sym) && !ASR::is_a(*final_sym) && \ - !ASR::is_a(*final_sym)) { \ - current_function_dependencies.push_back(al, ASRUtils::symbol_name(final_sym)); \ + !ASR::is_a(*final_sym) && !ASR::is_a(*final_sym)) { \ + if (ASR::is_a(*asr_owner_sym) || ASR::is_a(*asr_owner_sym)) { \ + temp_scope = temp_scope->parent; \ + if (temp_scope->get_counter() != ASRUtils::symbol_parent_symtab(final_sym)->get_counter()) { \ + current_function_dependencies.push_back(al, ASRUtils::symbol_name(final_sym)); \ + } \ + } else { \ + current_function_dependencies.push_back(al, ASRUtils::symbol_name(final_sym)); \ + } \ } \ #define ADD_ASR_DEPENDENCIES_WITH_NAME(current_scope, final_sym, current_function_dependencies, dep_name) ASR::symbol_t* asr_owner_sym = nullptr; \ @@ -31,9 +37,15 @@ } \ SymbolTable* temp_scope = current_scope; \ if (asr_owner_sym && temp_scope->get_counter() != ASRUtils::symbol_parent_symtab(final_sym)->get_counter() && \ - !ASR::is_a(*asr_owner_sym) && !ASR::is_a(*final_sym) && \ - !ASR::is_a(*final_sym)) { \ - current_function_dependencies.push_back(al, dep_name); \ + !ASR::is_a(*final_sym) && !ASR::is_a(*final_sym)) { \ + if (ASR::is_a(*asr_owner_sym) || ASR::is_a(*asr_owner_sym)) { \ + temp_scope = temp_scope->parent; \ + if (temp_scope->get_counter() != ASRUtils::symbol_parent_symtab(final_sym)->get_counter()) { \ + current_function_dependencies.push_back(al, dep_name); \ + } \ + } else { \ + current_function_dependencies.push_back(al, dep_name); \ + } \ } \ namespace LCompilers { @@ -1085,6 +1097,12 @@ static inline bool all_args_evaluated(const Vec &args) { static inline bool extract_value(ASR::expr_t* value_expr, std::complex& value) { + if ( ASR::is_a(*value_expr) ) { + value_expr = ASR::down_cast(value_expr)->m_value; + if (!value_expr) { + return false; + } + } if( !ASR::is_a(*value_expr) ) { return false; } @@ -1700,7 +1718,20 @@ void extract_module_python(const ASR::TranslationUnit_t &m, std::vector>& children_modules, std::string module_name); -void update_call_args(Allocator &al, SymbolTable *current_scope, bool implicit_interface); +static inline bool is_external_sym_changed(ASR::symbol_t* original_sym, ASR::symbol_t* external_sym) { + if (!ASR::is_a(*original_sym) || !ASR::is_a(*external_sym)) { + return false; + } + ASR::Function_t* original_func = ASR::down_cast(original_sym); + ASR::Function_t* external_func = ASR::down_cast(external_sym); + bool same_number_of_args = original_func->n_args == external_func->n_args; + // TODO: Check if the arguments are the same + return !(same_number_of_args); +} + +void update_call_args(Allocator &al, SymbolTable *current_scope, bool implicit_interface, + std::map changed_external_function_symbol); + ASR::Module_t* extract_module(const ASR::TranslationUnit_t &m); @@ -2055,30 +2086,64 @@ static inline ASR::asr_t* make_ArraySize_t_util( Allocator &al, const Location &a_loc, ASR::expr_t* a_v, ASR::expr_t* a_dim, ASR::ttype_t* a_type, ASR::expr_t* a_value, bool for_type=true) { + int dim = -1; + bool is_dimension_constant = (a_dim != nullptr) && ASRUtils::extract_value( + ASRUtils::expr_value(a_dim), dim); if( ASR::is_a(*a_v) ) { a_v = ASR::down_cast(a_v)->m_arg; } - ASR::dimension_t* m_dims = nullptr; - size_t n_dims = ASRUtils::extract_dimensions_from_ttype(ASRUtils::expr_type(a_v), m_dims); - bool is_dimension_dependent_only_on_arguments_ = is_dimension_dependent_only_on_arguments(m_dims, n_dims); - int dim = -1; - bool is_dimension_constant = (a_dim != nullptr) && ASRUtils::extract_value(ASRUtils::expr_value(a_dim), dim); - - bool compute_size = (is_dimension_dependent_only_on_arguments_ && - (is_dimension_constant || a_dim == nullptr)); - if( compute_size && for_type ) { - ASR::dimension_t* m_dims = nullptr; - size_t n_dims = ASRUtils::extract_dimensions_from_ttype(ASRUtils::expr_type(a_v), m_dims); + if( ASR::is_a(*a_v) ) { + ASR::ArraySection_t* array_section_t = ASR::down_cast(a_v); if( a_dim == nullptr ) { - ASR::asr_t* size = ASR::make_IntegerConstant_t(al, a_loc, 1, a_type); - for( size_t i = 0; i < n_dims; i++ ) { + ASR::asr_t* const1 = ASR::make_IntegerConstant_t(al, a_loc, 1, a_type); + ASR::asr_t* size = const1; + for( size_t i = 0; i < array_section_t->n_args; i++ ) { + ASR::expr_t* start = array_section_t->m_args[i].m_left; + ASR::expr_t* end = array_section_t->m_args[i].m_right; + ASR::expr_t* d = array_section_t->m_args[i].m_step; + ASR::expr_t* endminusstart = ASRUtils::EXPR(ASR::make_IntegerBinOp_t( + al, a_loc, end, ASR::binopType::Sub, start, a_type, nullptr)); + ASR::expr_t* byd = ASRUtils::EXPR(ASR::make_IntegerBinOp_t( + al, a_loc, endminusstart, ASR::binopType::Div, d, a_type, nullptr)); + ASR::expr_t* plus1 = ASRUtils::EXPR(ASR::make_IntegerBinOp_t( + al, a_loc, byd, ASR::binopType::Add, ASRUtils::EXPR(const1), a_type, nullptr)); size = ASR::make_IntegerBinOp_t(al, a_loc, ASRUtils::EXPR(size), - ASR::binopType::Mul, m_dims[i].m_length, a_type, nullptr); + ASR::binopType::Mul, plus1, a_type, nullptr); } return size; } else if( is_dimension_constant ) { - return (ASR::asr_t*) m_dims[dim - 1].m_length; + ASR::asr_t* const1 = ASR::make_IntegerConstant_t(al, a_loc, 1, a_type); + ASR::expr_t* start = array_section_t->m_args[dim - 1].m_left; + ASR::expr_t* end = array_section_t->m_args[dim - 1].m_right; + ASR::expr_t* d = array_section_t->m_args[dim - 1].m_step; + ASR::expr_t* endminusstart = ASRUtils::EXPR(ASR::make_IntegerBinOp_t( + al, a_loc, end, ASR::binopType::Sub, start, a_type, nullptr)); + ASR::expr_t* byd = ASRUtils::EXPR(ASR::make_IntegerBinOp_t( + al, a_loc, endminusstart, ASR::binopType::Div, d, a_type, nullptr)); + return ASR::make_IntegerBinOp_t(al, a_loc, byd, ASR::binopType::Add, + ASRUtils::EXPR(const1), a_type, nullptr); + } + } else { + ASR::dimension_t* m_dims = nullptr; + size_t n_dims = ASRUtils::extract_dimensions_from_ttype(ASRUtils::expr_type(a_v), m_dims); + bool is_dimension_dependent_only_on_arguments_ = is_dimension_dependent_only_on_arguments(m_dims, n_dims); + + bool compute_size = (is_dimension_dependent_only_on_arguments_ && + (is_dimension_constant || a_dim == nullptr)); + if( compute_size && for_type ) { + ASR::dimension_t* m_dims = nullptr; + size_t n_dims = ASRUtils::extract_dimensions_from_ttype(ASRUtils::expr_type(a_v), m_dims); + if( a_dim == nullptr ) { + ASR::asr_t* size = ASR::make_IntegerConstant_t(al, a_loc, 1, a_type); + for( size_t i = 0; i < n_dims; i++ ) { + size = ASR::make_IntegerBinOp_t(al, a_loc, ASRUtils::EXPR(size), + ASR::binopType::Mul, m_dims[i].m_length, a_type, nullptr); + } + return size; + } else if( is_dimension_constant ) { + return (ASR::asr_t*) m_dims[dim - 1].m_length; + } } } @@ -2412,6 +2477,14 @@ static inline ASR::ttype_t* duplicate_type_without_dims(Allocator& al, const ASR } } +inline std::string remove_trailing_white_spaces(std::string str) { + int end = str.size() - 1; + while (end >= 0 && str[end] == ' ') { + end--; + } + return str.substr(0, end + 1); +} + inline bool is_same_type_pointer(ASR::ttype_t* source, ASR::ttype_t* dest) { bool is_source_pointer = is_pointer(source), is_dest_pointer = is_pointer(dest); if( (!is_source_pointer && !is_dest_pointer) || @@ -3250,6 +3323,14 @@ class ReplaceArgVisitor: public ASR::BaseExprReplacer { } } + void replace_ArraySize(ASR::ArraySize_t* x) { + ASR::BaseExprReplacer::replace_ArraySize(x); + if( ASR::is_a(*x->m_v) ) { + *current_expr = ASRUtils::EXPR(ASRUtils::make_ArraySize_t_util( + al, x->base.base.loc, x->m_v, x->m_dim, x->m_type, x->m_value, true)); + } + } + }; // Finds the argument index that is equal to `v`, otherwise -1. @@ -3501,6 +3582,15 @@ class SymbolDuplicator { if( !node_duplicator.success ) { return nullptr; } + if (ASR::is_a(*m_type)) { + ASR::Struct_t* st = ASR::down_cast(m_type); + std::string derived_type_name = ASRUtils::symbol_name(st->m_derived_type); + ASR::symbol_t* derived_type_sym = destination_symtab->resolve_symbol(derived_type_name); + LCOMPILERS_ASSERT_MSG( derived_type_sym != nullptr, "derived_type_sym cannot be nullptr"); + if (derived_type_sym != st->m_derived_type) { + st->m_derived_type = derived_type_sym; + } + } return ASR::down_cast( ASR::make_Variable_t(al, variable->base.base.loc, destination_symtab, variable->m_name, variable->m_dependencies, variable->n_dependencies, @@ -3585,6 +3675,11 @@ class SymbolDuplicator { if( new_return_var ) { node_duplicator.success = true; new_return_var = node_duplicator.duplicate_expr(function->m_return_var); + if (ASR::is_a(*new_return_var)) { + ASR::Var_t* var = ASR::down_cast(new_return_var); + std::string var_sym_name = ASRUtils::symbol_name(var->m_v); + new_return_var = ASRUtils::EXPR(make_Var_t(al, var->base.base.loc, function_symtab->get_symbol(var_sym_name))); + } if( !node_duplicator.success ) { return nullptr; } @@ -4317,6 +4412,9 @@ inline ASR::asr_t* make_ArrayConstant_t_util(Allocator &al, const Location &a_lo return ASR::make_ArrayConstant_t(al, a_loc, a_args, n_args, a_type, a_storage_format); } +void make_ArrayBroadcast_t_util(Allocator& al, const Location& loc, + ASR::expr_t*& expr1, ASR::expr_t*& expr2); + static inline void Call_t_body(Allocator& al, ASR::symbol_t* a_name, ASR::call_arg_t* a_args, size_t n_args, ASR::expr_t* a_dt, ASR::stmt_t** cast_stmt, bool implicit_argument_casting) { bool is_method = a_dt != nullptr; @@ -4463,8 +4561,10 @@ static inline void Call_t_body(Allocator& al, ASR::symbol_t* a_name, } } if( ASRUtils::is_array(arg_type) && ASRUtils::is_array(orig_arg_type) ) { - ASR::Array_t* arg_array_t = ASR::down_cast(ASRUtils::type_get_past_const(arg_type)); - ASR::Array_t* orig_arg_array_t = ASR::down_cast(ASRUtils::type_get_past_const(orig_arg_type)); + ASR::Array_t* arg_array_t = ASR::down_cast( + ASRUtils::type_get_past_pointer(ASRUtils::type_get_past_const(arg_type))); + ASR::Array_t* orig_arg_array_t = ASR::down_cast( + ASRUtils::type_get_past_pointer(ASRUtils::type_get_past_const(orig_arg_type))); if( (arg_array_t->m_physical_type != orig_arg_array_t->m_physical_type) || (arg_array_t->m_physical_type == ASR::array_physical_typeType::DescriptorArray && arg_array_t->m_physical_type == orig_arg_array_t->m_physical_type && @@ -4629,6 +4729,8 @@ inline ASR::ttype_t* make_Pointer_t_util(Allocator& al, const Location& loc, ASR return ASRUtils::TYPE(ASR::make_Pointer_t(al, loc, type)); } +int64_t compute_trailing_zeros(int64_t number); + } // namespace ASRUtils } // namespace LCompilers diff --git a/src/libasr/asr_verify.cpp b/src/libasr/asr_verify.cpp index 785478765d7..7f6dc8484b0 100644 --- a/src/libasr/asr_verify.cpp +++ b/src/libasr/asr_verify.cpp @@ -441,20 +441,19 @@ class VerifyVisitor : public BaseWalkVisitor verify_unique_dependencies(x.m_dependencies, x.n_dependencies, x.m_name, x.base.base.loc); - // Get the x symtab. - SymbolTable *x_symtab = x.m_symtab; + // Get the x parent symtab. + SymbolTable *x_parent_symtab = x.m_symtab->parent; // Dependencies of the function should be from function's parent symbol table. for( size_t i = 0; i < x.n_dependencies; i++ ) { std::string found_dep = x.m_dependencies[i]; // Get the symbol of the found_dep. - ASR::symbol_t* dep_sym = x_symtab->resolve_symbol(found_dep); + ASR::symbol_t* dep_sym = x_parent_symtab->resolve_symbol(found_dep); require(dep_sym != nullptr, "Dependency " + found_dep + " is inside symbol table " + std::string(x.m_name)); } - // Check if there are unnecessary dependencies // present in the dependency list of the function for( size_t i = 0; i < x.n_dependencies; i++ ) { @@ -891,10 +890,16 @@ class VerifyVisitor : public BaseWalkVisitor SymbolTable* temp_scope = current_symtab; - if (temp_scope->get_counter() != ASRUtils::symbol_parent_symtab(x.m_name)->get_counter() && - !ASR::is_a(*asr_owner_sym) && !ASR::is_a(*x.m_name) && - !ASR::is_a(*x.m_name)) { - function_dependencies.push_back(std::string(ASRUtils::symbol_name(x.m_name))); + if (asr_owner_sym && temp_scope->get_counter() != ASRUtils::symbol_parent_symtab(x.m_name)->get_counter() && + !ASR::is_a(*x.m_name) && !ASR::is_a(*x.m_name)) { + if (ASR::is_a(*asr_owner_sym) || ASR::is_a(*asr_owner_sym)) { + temp_scope = temp_scope->parent; + if (temp_scope->get_counter() != ASRUtils::symbol_parent_symtab(x.m_name)->get_counter()) { + function_dependencies.push_back(std::string(ASRUtils::symbol_name(x.m_name))); + } + } else { + function_dependencies.push_back(std::string(ASRUtils::symbol_name(x.m_name))); + } } if( ASR::is_a(*x.m_name) ) { @@ -1037,9 +1042,15 @@ class VerifyVisitor : public BaseWalkVisitor SymbolTable* temp_scope = current_symtab; if (asr_owner_sym && temp_scope->get_counter() != ASRUtils::symbol_parent_symtab(x.m_name)->get_counter() && - !ASR::is_a(*asr_owner_sym) && !ASR::is_a(*x.m_name) && - !ASR::is_a(*x.m_name)) { - function_dependencies.push_back(std::string(ASRUtils::symbol_name(x.m_name))); + !ASR::is_a(*x.m_name) && !ASR::is_a(*x.m_name)) { + if (ASR::is_a(*asr_owner_sym) || ASR::is_a(*asr_owner_sym)) { + temp_scope = temp_scope->parent; + if (temp_scope->get_counter() != ASRUtils::symbol_parent_symtab(x.m_name)->get_counter()) { + function_dependencies.push_back(std::string(ASRUtils::symbol_name(x.m_name))); + } + } else { + function_dependencies.push_back(std::string(ASRUtils::symbol_name(x.m_name))); + } } if( ASR::is_a(*x.m_name) ) { diff --git a/src/libasr/codegen/KaleidoscopeJIT.h b/src/libasr/codegen/KaleidoscopeJIT.h index 1fe7a057e0b..28829bcad61 100644 --- a/src/libasr/codegen/KaleidoscopeJIT.h +++ b/src/libasr/codegen/KaleidoscopeJIT.h @@ -26,6 +26,12 @@ #include "llvm/IR/LLVMContext.h" #include +#if LLVM_VERSION_MAJOR >= 16 +# define RM_OPTIONAL_TYPE std::optional +#else +# define RM_OPTIONAL_TYPE llvm::Optional +#endif + namespace llvm { namespace orc { @@ -71,7 +77,7 @@ class KaleidoscopeJIT { auto CPU = "generic"; auto Features = ""; TargetOptions opt; - auto RM = Optional(); + auto RM = RM_OPTIONAL_TYPE(); TM = Target->createTargetMachine(TargetTriple, CPU, Features, opt, RM); } diff --git a/src/libasr/codegen/asr_to_c.cpp b/src/libasr/codegen/asr_to_c.cpp index e855255559a..00bb14ce192 100644 --- a/src/libasr/codegen/asr_to_c.cpp +++ b/src/libasr/codegen/asr_to_c.cpp @@ -19,7 +19,7 @@ #include #define CHECK_FAST_C(compiler_options, x) \ - if (compiler_options.fast && x.m_value != nullptr) { \ + if (compiler_options.po.fast && x.m_value != nullptr) { \ visit_expr(*x.m_value); \ return; \ } \ @@ -83,12 +83,19 @@ class ASRToCVisitor : public BaseCCPPVisitor ASR::dimension_t* m_dims, int n_dims, bool use_ref, bool dummy, bool declare_value, bool is_fixed_size, - bool is_pointer=false, - ASR::abiType m_abi=ASR::abiType::Source) { + bool is_pointer, + ASR::abiType m_abi, + bool is_simd_array) { std::string indent(indentation_level*indentation_spaces, ' '); std::string type_name_copy = type_name; + std::string original_type_name = type_name; type_name = c_ds_api->get_array_type(type_name, encoded_type_name, array_types_decls); std::string type_name_without_ptr = c_ds_api->get_array_type(type_name, encoded_type_name, array_types_decls, false); + if (is_simd_array) { + int64_t size = ASRUtils::get_fixed_size_of_array(m_dims, n_dims); + sub = original_type_name + " " + v_m_name + " __attribute__ (( vector_size(sizeof(" + original_type_name + ") * " + std::to_string(size) + ") ))"; + return; + } if( declare_value ) { std::string variable_name = std::string(v_m_name) + "_value"; sub = format_type_c("", type_name_without_ptr, variable_name, use_ref, dummy) + ";\n"; @@ -201,6 +208,9 @@ class ASRToCVisitor : public BaseCCPPVisitor } bool is_module_var = ASR::is_a( *ASR::down_cast(v.m_parent_symtab->asr_owner)); + bool is_simd_array = (ASR::is_a(*v.m_type) && + ASR::down_cast(v.m_type)->m_physical_type + == ASR::array_physical_typeType::SIMDArray); generate_array_decl(sub, force_declare_name, type_name, dims, encoded_type_name, m_dims, n_dims, use_ref, dummy, @@ -209,7 +219,7 @@ class ASRToCVisitor : public BaseCCPPVisitor v.m_intent != ASRUtils::intent_out && v.m_intent != ASRUtils::intent_unspecified && !is_struct_type_member && !is_module_var) || force_declare, - is_fixed_size); + is_fixed_size, false, ASR::abiType::Source, is_simd_array); } } else { bool is_fixed_size = true; @@ -288,7 +298,7 @@ class ASRToCVisitor : public BaseCCPPVisitor v.m_intent != ASRUtils::intent_in && v.m_intent != ASRUtils::intent_inout && v.m_intent != ASRUtils::intent_out && - v.m_intent != ASRUtils::intent_unspecified, is_fixed_size, true); + v.m_intent != ASRUtils::intent_unspecified, is_fixed_size, true, ASR::abiType::Source, false); } else { bool is_fixed_size = true; std::string dims = convert_dims_c(n_dims, m_dims, v_m_type, is_fixed_size); @@ -311,7 +321,7 @@ class ASRToCVisitor : public BaseCCPPVisitor v.m_intent != ASRUtils::intent_inout && v.m_intent != ASRUtils::intent_out && v.m_intent != ASRUtils::intent_unspecified, - is_fixed_size, true); + is_fixed_size, true, ASR::abiType::Source, false); } else { bool is_fixed_size = true; std::string dims = convert_dims_c(n_dims, m_dims, v_m_type, is_fixed_size); @@ -336,7 +346,10 @@ class ASRToCVisitor : public BaseCCPPVisitor if( is_array ) { bool is_fixed_size = true; std::string dims = convert_dims_c(n_dims, m_dims, v_m_type, is_fixed_size, true); - std::string encoded_type_name = "f" + std::to_string(t->m_kind * 8); + std::string encoded_type_name = ASRUtils::get_type_code(t2); + bool is_simd_array = (ASR::is_a(*v_m_type) && + ASR::down_cast(v_m_type)->m_physical_type + == ASR::array_physical_typeType::SIMDArray); generate_array_decl(sub, std::string(v.m_name), type_name, dims, encoded_type_name, m_dims, n_dims, use_ref, dummy, @@ -344,7 +357,7 @@ class ASRToCVisitor : public BaseCCPPVisitor v.m_intent != ASRUtils::intent_inout && v.m_intent != ASRUtils::intent_out && v.m_intent != ASRUtils::intent_unspecified, - is_fixed_size, true); + is_fixed_size, true, ASR::abiType::Source, is_simd_array); } else { bool is_fixed_size = true; std::string dims = convert_dims_c(n_dims, m_dims, v_m_type, is_fixed_size); @@ -365,7 +378,7 @@ class ASRToCVisitor : public BaseCCPPVisitor v.m_intent != ASRUtils::intent_inout && v.m_intent != ASRUtils::intent_out && v.m_intent != ASRUtils::intent_unspecified, - is_fixed_size); + is_fixed_size, false, ASR::abiType::Source, false); } else { std::string ptr_char = "*"; if( !use_ptr_for_derived_type ) { @@ -436,7 +449,7 @@ class ASRToCVisitor : public BaseCCPPVisitor v.m_intent != ASRUtils::intent_inout && v.m_intent != ASRUtils::intent_out && v.m_intent != ASRUtils::intent_unspecified, - is_fixed_size); + is_fixed_size, false, ASR::abiType::Source, false); } else if( v.m_intent == ASRUtils::intent_local && pre_initialise_derived_type) { bool is_fixed_size = true; dims = convert_dims_c(n_dims, m_dims, v_m_type, is_fixed_size); @@ -497,7 +510,9 @@ class ASRToCVisitor : public BaseCCPPVisitor v.m_intent != ASRUtils::intent_in && v.m_intent != ASRUtils::intent_inout && v.m_intent != ASRUtils::intent_out && - v.m_intent != ASRUtils::intent_unspecified, is_fixed_size); + v.m_intent != ASRUtils::intent_unspecified, is_fixed_size, + false, + ASR::abiType::Source, false); } else { bool is_fixed_size = true; dims = convert_dims_c(n_dims, m_dims, v_m_type, is_fixed_size); @@ -1189,6 +1204,34 @@ R"( // Initialise Numpy src = this->check_tmp_buffer() + out; } + void visit_ArrayBroadcast(const ASR::ArrayBroadcast_t& x) { + /* + !LF$ attributes simd :: A + real :: A(8) + A = 1 + We need to generate: + a = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}; + */ + + CHECK_FAST_C(compiler_options, x) + if (x.m_value) { + ASR::expr_t* value = x.m_value; + LCOMPILERS_ASSERT(ASR::is_a(*value)); + ASR::ArrayConstant_t* array_const = ASR::down_cast(value); + std::string array_const_str = "{"; + for( size_t i = 0; i < array_const->n_args; i++ ) { + ASR::expr_t* array_const_arg = array_const->m_args[i]; + this->visit_expr(*array_const_arg); + array_const_str += src + ", "; + } + array_const_str.pop_back(); + array_const_str.pop_back(); + array_const_str += "}"; + + src = array_const_str; + } + } + void visit_ArraySize(const ASR::ArraySize_t& x) { CHECK_FAST_C(compiler_options, x) visit_expr(*x.m_v); diff --git a/src/libasr/codegen/asr_to_c_cpp.h b/src/libasr/codegen/asr_to_c_cpp.h index db43a981751..9cda6714b16 100644 --- a/src/libasr/codegen/asr_to_c_cpp.h +++ b/src/libasr/codegen/asr_to_c_cpp.h @@ -29,7 +29,7 @@ #include #define CHECK_FAST_C_CPP(compiler_options, x) \ - if (compiler_options.fast && x.m_value != nullptr) { \ + if (compiler_options.po.fast && x.m_value != nullptr) { \ self().visit_expr(*x.m_value); \ return; \ } \ @@ -1051,6 +1051,13 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) { void visit_ArrayPhysicalCast(const ASR::ArrayPhysicalCast_t& x) { src = ""; this->visit_expr(*x.m_arg); + if (x.m_old == ASR::array_physical_typeType::FixedSizeArray && + x.m_new == ASR::array_physical_typeType::SIMDArray) { + std::string arr_element_type = CUtils::get_c_type_from_ttype_t(ASRUtils::expr_type(x.m_arg)); + int64_t size = ASRUtils::get_fixed_size_of_array(ASRUtils::expr_type(x.m_arg)); + std::string cast = arr_element_type + " __attribute__ (( vector_size(sizeof(" + arr_element_type + ") * " + std::to_string(size) + ") ))"; + src = "(" + cast + ") " + src; + } } std::string construct_call_args(ASR::Function_t* f, size_t n_args, ASR::call_arg_t* m_args) { @@ -1353,7 +1360,7 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) { bool is_target_data_only_array = ASRUtils::is_fixed_size_array(m_target_dims, n_target_dims) && ASR::is_a(*ASRUtils::get_asr_owner(x.m_target)); bool is_value_data_only_array = ASRUtils::is_fixed_size_array(m_value_dims, n_value_dims) && - ASR::is_a(*ASRUtils::get_asr_owner(x.m_value)); + ASRUtils::get_asr_owner(x.m_value) && ASR::is_a(*ASRUtils::get_asr_owner(x.m_value)); if( is_target_data_only_array || is_value_data_only_array ) { int64_t target_size = -1, value_size = -1; if( !is_target_data_only_array ) { @@ -1390,6 +1397,37 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) { from_std_vector_helper.clear(); } + void visit_Associate(const ASR::Associate_t &x) { + if (ASR::is_a(*x.m_value)) { + self().visit_expr(*x.m_target); + std::string target = std::move(src); + // ArraySection(expr v, array_index* args, ttype type, expr? value) + ASR::ArraySection_t *as = ASR::down_cast(x.m_value); + self().visit_expr(*as->m_v); + std::string value = std::move(src); + std::string c = ""; + for( size_t i = 0; i < as->n_args; i++ ) { + std::string left, right, step; + if (as->m_args[i].m_left) { + self().visit_expr(*as->m_args[i].m_left); + left = std::move(src); + } + if (as->m_args[i].m_right) { + self().visit_expr(*as->m_args[i].m_right); + right = std::move(src); + } + if (as->m_args[i].m_step) { + self().visit_expr(*as->m_args[i].m_step); + step = std::move(src); + } + c += left + ":" + right + ":" + step + ","; + } + src = target + "= " + value + "; // TODO: " + value + "(" + c + ")\n"; + } else { + throw CodeGenError("Associate only implemented for ArraySection so far"); + } + } + void visit_IntegerConstant(const ASR::IntegerConstant_t &x) { src = std::to_string(x.m_n); last_expr_precedence = 2; @@ -2359,7 +2397,8 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) { } } - void visit_Allocate(const ASR::Allocate_t &x) { + template + void handle_alloc_realloc(const T &x) { std::string indent(indentation_level*indentation_spaces, ' '); std::string out = ""; for (size_t i=0; i asr_to_cpp(Allocator &al, ASR::TranslationUnit_t &asr, diag::Diagnostics &diagnostics, CompilerOptions &co, int64_t default_lower_bound) { - LCompilers::PassOptions pass_options; - pass_options.always_run = true; - pass_options.fast = co.fast; - pass_unused_functions(al, asr, pass_options); + co.po.always_run = true; + pass_unused_functions(al, asr, co.po); ASRToCPPVisitor v(diagnostics, co, default_lower_bound); try { v.visit_asr((ASR::asr_t &)asr); diff --git a/src/libasr/codegen/asr_to_fortran.cpp b/src/libasr/codegen/asr_to_fortran.cpp new file mode 100644 index 00000000000..dae695c1dcb --- /dev/null +++ b/src/libasr/codegen/asr_to_fortran.cpp @@ -0,0 +1,1875 @@ +#include +#include +#include +#include +#include + +using LCompilers::ASR::is_a; +using LCompilers::ASR::down_cast; + +namespace LCompilers { + +enum Precedence { + Eqv = 2, + NEqv = 2, + Or = 3, + And = 4, + Not = 5, + CmpOp = 6, + Add = 8, + Sub = 8, + UnaryMinus = 9, + Mul = 10, + Div = 10, + Pow = 11, + Ext = 13, +}; + +class ASRToFortranVisitor : public ASR::BaseVisitor +{ +public: + std::string s; + bool use_colors; + int indent_level; + std::string indent; + int indent_spaces; + // The precedence of the last expression, using the table 10.1 + // in the Fortran 2018 standard + int last_expr_precedence; + std::string format_string; + + // Used for importing struct type inside interface + bool is_interface = false; + std::vector import_struct_type; + +public: + ASRToFortranVisitor(bool _use_colors, int _indent) + : use_colors{_use_colors}, indent_level{0}, + indent_spaces{_indent} + { } + + /********************************** Utils *********************************/ + void inc_indent() { + indent_level++; + indent = std::string(indent_level*indent_spaces, ' '); + } + + void dec_indent() { + indent_level--; + indent = std::string(indent_level*indent_spaces, ' '); + } + + void visit_expr_with_precedence(const ASR::expr_t &x, int current_precedence) { + visit_expr(x); + if (last_expr_precedence == 9 || + last_expr_precedence < current_precedence) { + s = "(" + s + ")"; + } + } + + std::string binop2str(const ASR::binopType type) { + switch (type) { + case (ASR::binopType::Add) : { + last_expr_precedence = Precedence::Add; + return " + "; + } case (ASR::binopType::Sub) : { + last_expr_precedence = Precedence::Sub; + return " - "; + } case (ASR::binopType::Mul) : { + last_expr_precedence = Precedence::Mul; + return "*"; + } case (ASR::binopType::Div) : { + last_expr_precedence = Precedence::Div; + return "/"; + } case (ASR::binopType::Pow) : { + last_expr_precedence = Precedence::Pow; + return "**"; + } default : { + throw LCompilersException("Binop type not implemented"); + } + } + } + + std::string cmpop2str(const ASR::cmpopType type) { + last_expr_precedence = Precedence::CmpOp; + switch (type) { + case (ASR::cmpopType::Eq) : return " == "; + case (ASR::cmpopType::NotEq) : return " /= "; + case (ASR::cmpopType::Lt) : return " < " ; + case (ASR::cmpopType::LtE) : return " <= "; + case (ASR::cmpopType::Gt) : return " > " ; + case (ASR::cmpopType::GtE) : return " >= "; + default : throw LCompilersException("Cmpop type not implemented"); + } + } + + std::string logicalbinop2str(const ASR::logicalbinopType type) { + switch (type) { + case (ASR::logicalbinopType::And) : { + last_expr_precedence = Precedence::And; + return " .and. "; + } case (ASR::logicalbinopType::Or) : { + last_expr_precedence = Precedence::Or; + return " .or. "; + } case (ASR::logicalbinopType::Eqv) : { + last_expr_precedence = Precedence::Eqv; + return " .eqv. "; + } case (ASR::logicalbinopType::NEqv) : { + last_expr_precedence = Precedence::NEqv; + return " .neqv. "; + } default : { + throw LCompilersException("Logicalbinop type not implemented"); + } + } + } + + template + void visit_body(const T &x, std::string &r, bool apply_indent=true) { + if (apply_indent) { + inc_indent(); + } + for (size_t i = 0; i < x.n_body; i++) { + visit_stmt(*x.m_body[i]); + r += s; + } + if (apply_indent) { + dec_indent(); + } + } + + std::string get_type(const ASR::ttype_t *t) { + std::string r = ""; + switch (t->type) { + case ASR::ttypeType::Integer: { + r = "integer("; + r += std::to_string(down_cast(t)->m_kind); + r += ")"; + break; + } case ASR::ttypeType::Real: { + r = "real("; + r += std::to_string(down_cast(t)->m_kind); + r += ")"; + break; + } case ASR::ttypeType::Complex: { + r = "complex("; + r += std::to_string(down_cast(t)->m_kind); + r += ")"; + break; + } case ASR::ttypeType::Character: { + ASR::Character_t *c = down_cast(t); + r = "character(len="; + if(c->m_len > 0) { + r += std::to_string(c->m_len); + } else { + if (c->m_len == -1) { + r += "*"; + } else if (c->m_len == -2) { + r += ":"; + } else if (c->m_len == -3) { + visit_expr(*c->m_len_expr); + r += s; + } + } + r += ", kind="; + r += std::to_string(c->m_kind); + r += ")"; + break; + } case ASR::ttypeType::Logical: { + r = "logical("; + r += std::to_string(down_cast(t)->m_kind); + r += ")"; + break; + } case ASR::ttypeType::Array: { + ASR::Array_t* arr_type = down_cast(t); + std::string bounds = ""; + for (size_t i = 0; i < arr_type->n_dims; i++) { + if (i > 0) bounds += ", "; + std::string start = "", len = ""; + if (arr_type->m_dims[i].m_start) { + visit_expr(*arr_type->m_dims[i].m_start); + start = s; + } + if (arr_type->m_dims[i].m_length) { + visit_expr(*arr_type->m_dims[i].m_length); + len = s; + } + + if (len.length() == 0) { + bounds += ":"; + } else { + if (start.length() == 0 || start == "1") { + bounds += len; + } else { + bounds += start + ":(" + start + ")+(" + len + ")-1"; + } + } + } + r = get_type(arr_type->m_type) + ", dimension(" + bounds + ")"; + break; + } case ASR::ttypeType::Allocatable: { + r = get_type(down_cast(t)->m_type) + ", allocatable"; + break; + } case ASR::ttypeType::Pointer: { + r = get_type(down_cast(t)->m_type) + ", pointer"; + break; + } case ASR::ttypeType::Struct: { + ASR::Struct_t* struct_type = down_cast(t); + std::string struct_name = ASRUtils::symbol_name(struct_type->m_derived_type); + r = "type("; + r += struct_name; + r += ")"; + if (std::find(import_struct_type.begin(), import_struct_type.end(), + struct_name) == import_struct_type.end() && is_interface) { + // Push unique struct names; + import_struct_type.push_back(struct_name); + } + break; + } + default: + throw LCompilersException("The type `" + + ASRUtils::type_to_str_python(t) + "` is not handled yet"); + } + return r; + } + + template + void handle_compare(const T& x) { + std::string r = "", m_op = cmpop2str(x.m_op); + int current_precedence = last_expr_precedence; + visit_expr_with_precedence(*x.m_left, current_precedence); + r += s; + r += m_op; + visit_expr_with_precedence(*x.m_right, current_precedence); + r += s; + last_expr_precedence = current_precedence; + s = r; + } + + /********************************** Unit **********************************/ + void visit_TranslationUnit(const ASR::TranslationUnit_t &x) { + std::string r = ""; + for (auto &item : x.m_symtab->get_scope()) { + if (is_a(*item.second)) { + visit_symbol(*item.second); + r += s; + } + } + + for (auto &item : x.m_symtab->get_scope()) { + if (is_a(*item.second)) { + visit_symbol(*item.second); + r += s; + } + } + + for (auto &item : x.m_symtab->get_scope()) { + if (is_a(*item.second)) { + visit_symbol(*item.second); + r += s; + } + } + s = r; + } + + /********************************* Symbol *********************************/ + void visit_Program(const ASR::Program_t &x) { + std::string r; + r = "program"; + r += " "; + r.append(x.m_name); + r += "\n"; + for (auto &item : x.m_symtab->get_scope()) { + if (is_a(*item.second)) { + visit_symbol(*item.second); + r += s; + } + } + r += indent + "implicit none"; + r += "\n"; + std::map> struct_dep_graph; + for (auto &item : x.m_symtab->get_scope()) { + if (ASR::is_a(*item.second) || + ASR::is_a(*item.second) || + ASR::is_a(*item.second)) { + std::vector struct_deps_vec; + std::pair struct_deps_ptr = ASRUtils::symbol_dependencies(item.second); + for( size_t i = 0; i < struct_deps_ptr.second; i++ ) { + struct_deps_vec.push_back(std::string(struct_deps_ptr.first[i])); + } + struct_dep_graph[item.first] = struct_deps_vec; + } + } + + std::vector struct_deps = ASRUtils::order_deps(struct_dep_graph); + for (auto &item : struct_deps) { + ASR::symbol_t* struct_sym = x.m_symtab->get_symbol(item); + visit_symbol(*struct_sym); + r += s; + } + std::vector var_order = ASRUtils::determine_variable_declaration_order(x.m_symtab); + for (auto &item : var_order) { + ASR::symbol_t* var_sym = x.m_symtab->get_symbol(item); + if (is_a(*var_sym)) { + visit_symbol(*var_sym); + r += s; + } + } + + visit_body(x, r, false); + + bool prepend_contains_keyword = true; + for (auto &item : x.m_symtab->get_scope()) { + if (is_a(*item.second)) { + if (prepend_contains_keyword) { + prepend_contains_keyword = false; + r += "\n"; + r += "contains"; + r += "\n\n"; + } + visit_symbol(*item.second); + r += s; + } + } + r += "end program"; + r += " "; + r.append(x.m_name); + r += "\n"; + s = r; + } + + void visit_Module(const ASR::Module_t &x) { + std::string r; + r = "module"; + r += " "; + r.append(x.m_name); + r += "\n"; + for (auto &item : x.m_symtab->get_scope()) { + if (is_a(*item.second)) { + visit_symbol(*item.second); + r += s; + } + } + r += indent + "implicit none"; + r += "\n"; + for (auto &item : x.m_symtab->get_scope()) { + if (is_a(*item.second)) { + visit_symbol(*item.second); + r += s; + + } + } + std::map> struct_dep_graph; + for (auto &item : x.m_symtab->get_scope()) { + if (ASR::is_a(*item.second) || + ASR::is_a(*item.second) || + ASR::is_a(*item.second)) { + std::vector struct_deps_vec; + std::pair struct_deps_ptr = ASRUtils::symbol_dependencies(item.second); + for( size_t i = 0; i < struct_deps_ptr.second; i++ ) { + struct_deps_vec.push_back(std::string(struct_deps_ptr.first[i])); + } + struct_dep_graph[item.first] = struct_deps_vec; + } + } + + std::vector struct_deps = ASRUtils::order_deps(struct_dep_graph); + for (auto &item : struct_deps) { + ASR::symbol_t* struct_sym = x.m_symtab->get_symbol(item); + visit_symbol(*struct_sym); + r += s; + } + std::vector var_order = ASRUtils::determine_variable_declaration_order(x.m_symtab); + for (auto &item : var_order) { + ASR::symbol_t* var_sym = x.m_symtab->get_symbol(item); + if (is_a(*var_sym)) { + visit_symbol(*var_sym); + r += s; + } + } + std::vector func_name; + std::vector interface_func_name; + for (auto &item : x.m_symtab->get_scope()) { + if (is_a(*item.second)) { + ASR::Function_t *f = down_cast(item.second); + if (ASRUtils::get_FunctionType(f)->m_deftype == ASR::deftypeType::Interface) { + interface_func_name.push_back(item.first); + } else { + func_name.push_back(item.first); + } + } + } + for (size_t i = 0; i < interface_func_name.size(); i++) { + if (i == 0) { + r += "interface\n"; + is_interface = true; + inc_indent(); + } + visit_symbol(*x.m_symtab->get_symbol(interface_func_name[i])); + r += s; + if (i < interface_func_name.size() - 1) { + r += "\n"; + } else { + dec_indent(); + is_interface = false; + r += "end interface\n"; + } + } + for (size_t i = 0; i < func_name.size(); i++) { + if (i == 0) { + r += "\n"; + r += "contains"; + r += "\n\n"; + } + visit_symbol(*x.m_symtab->get_symbol(func_name[i])); + r += s; + if (i < func_name.size()) r += "\n"; + } + r += "end module"; + r += " "; + r.append(x.m_name); + r += "\n"; + s = r; + } + + void visit_Function(const ASR::Function_t &x) { + std::string r = indent; + ASR::FunctionType_t *type = ASR::down_cast(x.m_function_signature); + if (type->m_pure) { + r += "pure "; + } + if (type->m_elemental) { + r += "elemental "; + } + bool is_return_var_declared = false; + if (x.m_return_var) { + if (!ASRUtils::is_array(ASRUtils::expr_type(x.m_return_var))) { + is_return_var_declared = true; + r += get_type(ASRUtils::expr_type(x.m_return_var)); + r += " "; + } + r += "function"; + } else { + r += "subroutine"; + } + r += " "; + r.append(x.m_name); + r += "("; + for (size_t i = 0; i < x.n_args; i ++) { + visit_expr(*x.m_args[i]); + r += s; + if (i < x.n_args-1) r += ", "; + } + r += ")"; + if (type->m_abi == ASR::abiType::BindC) { + r += " bind(c"; + if (type->m_bindc_name) { + r += ", name = \""; + r += type->m_bindc_name; + r += "\""; + } + r += ")"; + } + std::string return_var = ""; + if (x.m_return_var) { + LCOMPILERS_ASSERT(is_a(*x.m_return_var)); + visit_expr(*x.m_return_var); + return_var = s; + if (strcmp(x.m_name, return_var.c_str())) { + r += " result(" + return_var + ")"; + } + } + r += "\n"; + + inc_indent(); + { + std::string variable_declaration; + std::vector var_order = ASRUtils::determine_variable_declaration_order(x.m_symtab); + for (auto &item : var_order) { + if (is_return_var_declared && item == return_var) continue; + ASR::symbol_t* var_sym = x.m_symtab->get_symbol(item); + if (is_a(*var_sym)) { + visit_symbol(*var_sym); + variable_declaration += s; + } + } + for (size_t i = 0; i < import_struct_type.size(); i ++) { + if (i == 0) { + r += indent; + r += "import "; + } + r += import_struct_type[i]; + if (i < import_struct_type.size() - 1) { + r += ", "; + } else { + r += "\n"; + } + } + import_struct_type.clear(); + r += variable_declaration; + } + + // Interface + for (auto &item : x.m_symtab->get_scope()) { + if (is_a(*item.second)) { + ASR::Function_t *f = down_cast(item.second); + if (ASRUtils::get_FunctionType(f)->m_deftype == ASR::deftypeType::Interface) { + is_interface = true; + r += indent; + r += "interface\n"; + inc_indent(); + visit_symbol(*item.second); + r += s; + r += "\n"; + dec_indent(); + r += indent; + r += "end interface\n"; + is_interface = false; + } else { + throw CodeGenError("Nested Function is not handled yet"); + } + } + } + + visit_body(x, r, false); + dec_indent(); + r += indent; + r += "end "; + if (x.m_return_var) { + r += "function"; + } else { + r += "subroutine"; + } + r += " "; + r.append(x.m_name); + r += "\n"; + s = r; + } + + void visit_GenericProcedure(const ASR::GenericProcedure_t &x) { + std::string r = indent; + r += "interface "; + r.append(x.m_name); + r += "\n"; + inc_indent(); + r += indent; + r += "module procedure "; + for (size_t i = 0; i < x.n_procs; i++) { + r += ASRUtils::symbol_name(x.m_procs[i]); + if (i < x.n_procs-1) r += ", "; + } + dec_indent(); + r += "\n"; + r += "end interface "; + r.append(x.m_name); + r += "\n"; + s = r; + } + + // void visit_CustomOperator(const ASR::CustomOperator_t &x) {} + + void visit_ExternalSymbol(const ASR::ExternalSymbol_t &x) { + ASR::symbol_t *sym = down_cast( + ASRUtils::symbol_parent_symtab(x.m_external)->asr_owner); + if (!is_a(*sym)) { + s = indent; + s += "use "; + s.append(x.m_module_name); + s += ", only: "; + s.append(x.m_original_name); + s += "\n"; + } + } + + void visit_StructType(const ASR::StructType_t &x) { + std::string r = indent; + r += "type :: "; + r.append(x.m_name); + r += "\n"; + inc_indent(); + std::vector var_order = ASRUtils::determine_variable_declaration_order(x.m_symtab); + for (auto &item : var_order) { + ASR::symbol_t* var_sym = x.m_symtab->get_symbol(item); + if (is_a(*var_sym)) { + visit_symbol(*var_sym); + r += s; + } + } + dec_indent(); + r += "end type "; + r.append(x.m_name); + r += "\n"; + s = r; + } + + // void visit_EnumType(const ASR::EnumType_t &x) {} + + // void visit_UnionType(const ASR::UnionType_t &x) {} + + void visit_Variable(const ASR::Variable_t &x) { + std::string r = indent; + std::string dims = "("; + r += get_type(x.m_type); + switch (x.m_intent) { + case ASR::intentType::In : { + r += ", intent(in)"; + break; + } case ASR::intentType::InOut : { + r += ", intent(inout)"; + break; + } case ASR::intentType::Out : { + r += ", intent(out)"; + break; + } case ASR::intentType::Local : { + // Pass + break; + } case ASR::intentType::ReturnVar : { + // Pass + break; + } case ASR::intentType::Unspecified : { + // Pass + break; + } + default: + throw LCompilersException("Intent type is not handled"); + } + if (x.m_presence == ASR::presenceType::Optional) { + r += ", optional"; + } + if (x.m_storage == ASR::storage_typeType::Parameter) { + r += ", parameter"; + } else if (x.m_storage == ASR::storage_typeType::Save) { + r += ", save"; + } + if (x.m_value_attr) { + r += ", value"; + } + r += " :: "; + r.append(x.m_name); + if (x.m_value) { + r += " = "; + visit_expr(*x.m_value); + r += s; + } else if (x.m_symbolic_value) { + r += " = "; + visit_expr(*x.m_symbolic_value); + r += s; + } + r += "\n"; + s = r; + } + + // void visit_ClassType(const ASR::ClassType_t &x) {} + + // void visit_ClassProcedure(const ASR::ClassProcedure_t &x) {} + + // void visit_AssociateBlock(const ASR::AssociateBlock_t &x) {} + + // void visit_Block(const ASR::Block_t &x) {} + + // void visit_Requirement(const ASR::Requirement_t &x) {} + + // void visit_Template(const ASR::Template_t &x) {} + + /********************************** Stmt **********************************/ + void visit_Allocate(const ASR::Allocate_t &x) { + std::string r = indent; + r += "allocate("; + for (size_t i = 0; i < x.n_args; i ++) { + visit_expr(*x.m_args[i].m_a); + r += s; + if (x.m_args[i].n_dims > 0) { + r += "("; + for (size_t j = 0; j < x.m_args[i].n_dims; j ++) { + visit_expr(*x.m_args[i].m_dims[j].m_length); + r += s; + if (j < x.m_args[i].n_dims-1) r += ", "; + } + r += ")"; + } + } + r += ")\n"; + s = r; + } + + // void visit_ReAlloc(const ASR::ReAlloc_t &x) {} + + void visit_Assign(const ASR::Assign_t &x) { + std::string r; + r += "assign"; + r += " "; + r += x.m_label; + r += " "; + r += "to"; + r += " "; + r += x.m_variable; + r += "\n"; + s = r; + } + + void visit_Assignment(const ASR::Assignment_t &x) { + std::string r = indent; + visit_expr(*x.m_target); + r += s; + r += " = "; + visit_expr(*x.m_value); + r += s; + r += "\n"; + s = r; + } + + void visit_Associate(const ASR::Associate_t &x) { + visit_expr(*x.m_target); + std::string t = std::move(s); + visit_expr(*x.m_value); + std::string v = std::move(s); + s = t + " => " + v + "\n"; + } + + void visit_Cycle(const ASR::Cycle_t &x) { + s = indent + "cycle"; + if (x.m_stmt_name) { + s += " " + std::string(x.m_stmt_name); + } + s += "\n"; + } + + // void visit_ExplicitDeallocate(const ASR::ExplicitDeallocate_t &x) {} + + void visit_ImplicitDeallocate(const ASR::ImplicitDeallocate_t &x) { + std::string r = indent; + r += "deallocate("; + for (size_t i = 0; i < x.n_vars; i ++) { + visit_expr(*x.m_vars[i]); + r += s; + if (i < x.n_vars-1) r += ", "; + } + r += ") "; + r += "! Implicit deallocate\n"; + s = r; + } + + // void visit_DoConcurrentLoop(const ASR::DoConcurrentLoop_t &x) {} + + void visit_DoLoop(const ASR::DoLoop_t &x) { + std::string r = indent; + if (x.m_name) { + r += std::string(x.m_name); + r += " : "; + } + + r += "do "; + visit_expr(*x.m_head.m_v); + r += s; + r += " = "; + visit_expr(*x.m_head.m_start); + r += s; + r += ", "; + visit_expr(*x.m_head.m_end); + r += s; + if (x.m_head.m_increment) { + r += ", "; + visit_expr(*x.m_head.m_increment); + r += s; + } + r += "\n"; + visit_body(x, r); + r += indent; + r += "end do"; + if (x.m_name) { + r += " " + std::string(x.m_name); + } + r += "\n"; + s = r; + } + + void visit_ErrorStop(const ASR::ErrorStop_t &/*x*/) { + s = indent; + s += "error stop"; + s += "\n"; + } + + void visit_Exit(const ASR::Exit_t &x) { + s = indent + "exit"; + if (x.m_stmt_name) { + s += " " + std::string(x.m_stmt_name); + } + s += "\n"; + } + + // void visit_ForAllSingle(const ASR::ForAllSingle_t &x) {} + + void visit_GoTo(const ASR::GoTo_t &x) { + std::string r = indent; + r += "go to"; + r += " "; + r += std::to_string(x.m_target_id); + r += "\n"; + s = r; + } + + void visit_GoToTarget(const ASR::GoToTarget_t &x) { + std::string r = ""; + r += std::to_string(x.m_id); + r += " "; + r += "continue"; + r += "\n"; + s = r; + } + + void visit_If(const ASR::If_t &x) { + std::string r = indent; + r += "if"; + r += " ("; + visit_expr(*x.m_test); + r += s; + r += ") "; + r += "then"; + r += "\n"; + visit_body(x, r); + for (size_t i = 0; i < x.n_orelse; i++) { + r += indent; + r += "else"; + r += "\n"; + inc_indent(); + visit_stmt(*x.m_orelse[i]); + r += s; + dec_indent(); + } + r += indent; + r += "end if"; + r += "\n"; + s = r; + } + + // void visit_IfArithmetic(const ASR::IfArithmetic_t &x) {} + + void visit_Print(const ASR::Print_t &x) { + std::string r = indent; + r += "print"; + r += " "; + if (x.n_values > 0 && is_a(*x.m_values[0])) { + ASR::StringFormat_t *sf = down_cast(x.m_values[0]); + visit_expr(*sf->m_fmt); + if (is_a(*sf->m_fmt) + && (!startswith(s, "\"(") || !endswith(s, ")\""))) { + s = "\"(" + s.substr(1, s.size()-2) + ")\""; + } + r += s; + } else { + r += "*"; + } + for (size_t i = 0; i < x.n_values; i++) { + r += ", "; + visit_expr(*x.m_values[i]); + r += s; + } + r += "\n"; + s = r; + } + + void visit_FileOpen(const ASR::FileOpen_t &x) { + std::string r; + r = indent; + r += "open"; + r += "("; + if (x.m_newunit) { + visit_expr(*x.m_newunit); + r += s; + } else { + throw CodeGenError("open() function must be called with a file unit number"); + } + if (x.m_filename) { + r += ", "; + r += "file="; + visit_expr(*x.m_filename); + r += s; + } + if (x.m_status) { + r += ", "; + r += "status="; + visit_expr(*x.m_status); + r += s; + } + if (x.m_form) { + r += ", "; + r += "form="; + visit_expr(*x.m_form); + r += s; + } + r += ")"; + r += "\n"; + s = r; + } + + void visit_FileClose(const ASR::FileClose_t &x) { + std::string r; + r = indent; + r += "close"; + r += "("; + if (x.m_unit) { + visit_expr(*x.m_unit); + r += s; + } else { + throw CodeGenError("close() function must be called with a file unit number"); + } + r += ")"; + r += "\n"; + s = r; + } + + void visit_FileRead(const ASR::FileRead_t &x) { + std::string r; + r = indent; + r += "read"; + r += "("; + if (x.m_unit) { + visit_expr(*x.m_unit); + r += s; + } else { + r += "*"; + } + if (x.m_fmt) { + r += ", "; + r += "fmt="; + visit_expr(*x.m_fmt); + r += s; + } else { + r += ", *"; + } + if (x.m_iomsg) { + r += ", "; + r += "iomsg="; + visit_expr(*x.m_iomsg); + r += s; + } + if (x.m_iostat) { + r += ", "; + r += "iostat="; + visit_expr(*x.m_iostat); + r += s; + } + if (x.m_id) { + r += ", "; + r += "id="; + visit_expr(*x.m_id); + r += s; + } + r += ") "; + for (size_t i = 0; i < x.n_values; i++) { + visit_expr(*x.m_values[i]); + r += s; + if (i < x.n_values - 1) r += ", "; + } + r += "\n"; + s = r; + } + + // void visit_FileBackspace(const ASR::FileBackspace_t &x) {} + + // void visit_FileRewind(const ASR::FileRewind_t &x) {} + + // void visit_FileInquire(const ASR::FileInquire_t &x) {} + + void visit_FileWrite(const ASR::FileWrite_t &x) { + std::string r = indent; + r += "write"; + r += "("; + if (!x.m_unit) { + r += "*, "; + } + if (x.n_values > 0 && is_a(*x.m_values[0])) { + ASR::StringFormat_t *sf = down_cast(x.m_values[0]); + visit_expr(*sf->m_fmt); + if (is_a(*sf->m_fmt) + && (!startswith(s, "\"(") || !endswith(s, ")\""))) { + s = "\"(" + s.substr(1, s.size()-2) + ")\""; + } + r += s; + } else { + r += "*"; + } + r += ") "; + for (size_t i = 0; i < x.n_values; i++) { + visit_expr(*x.m_values[i]); + r += s; + if (i < x.n_values-1) r += ", "; + } + r += "\n"; + s = r; + } + + void visit_Return(const ASR::Return_t &/*x*/) { + std::string r = indent; + r += "return"; + r += "\n"; + s = r; + } + + void visit_Select(const ASR::Select_t &x) { + std::string r = indent; + r += "select case"; + r += " ("; + visit_expr(*x.m_test); + r += s; + r += ")\n"; + inc_indent(); + if (x.n_body > 0) { + for(size_t i = 0; i < x.n_body; i ++) { + visit_case_stmt(*x.m_body[i]); + r += s; + } + } + + if (x.n_default > 0) { + r += indent; + r += "case default\n"; + inc_indent(); + for(size_t i = 0; i < x.n_default; i ++) { + visit_stmt(*x.m_default[i]); + r += s; + } + dec_indent(); + } + dec_indent(); + r += indent; + r += "end select\n"; + s = r; + } + + void visit_Stop(const ASR::Stop_t /*x*/) { + s = indent; + s += "stop"; + s += "\n"; + } + + // void visit_Assert(const ASR::Assert_t &x) {} + + void visit_SubroutineCall(const ASR::SubroutineCall_t &x) { + std::string r = indent; + r += "call "; + r += ASRUtils::symbol_name(x.m_name); + r += "("; + for (size_t i = 0; i < x.n_args; i ++) { + visit_expr(*x.m_args[i].m_value); + r += s; + if (i < x.n_args-1) r += ", "; + } + r += ")\n"; + s = r; + } + + void visit_Where(const ASR::Where_t &x) { + std::string r; + r = indent; + r += "where"; + r += " "; + r += "("; + visit_expr(*x.m_test); + r += s; + r += ")\n"; + visit_body(x, r); + for (size_t i = 0; i < x.n_orelse; i++) { + r += indent; + r += "else where"; + r += "\n"; + inc_indent(); + visit_stmt(*x.m_orelse[i]); + r += s; + dec_indent(); + } + r += indent; + r += "end where"; + r += "\n"; + s = r; + } + + void visit_WhileLoop(const ASR::WhileLoop_t &x) { + std::string r = indent; + if (x.m_name) { + r += std::string(x.m_name); + r += " : "; + } + r += "do while"; + r += " ("; + visit_expr(*x.m_test); + r += s; + r += ")\n"; + visit_body(x, r); + r += indent; + r += "end do"; + if (x.m_name) { + r += " " + std::string(x.m_name); + } + r += "\n"; + s = r; + } + + // void visit_Nullify(const ASR::Nullify_t &x) {} + + // void visit_Flush(const ASR::Flush_t &x) {} + + // void visit_AssociateBlockCall(const ASR::AssociateBlockCall_t &x) {} + + // void visit_SelectType(const ASR::SelectType_t &x) {} + + // void visit_CPtrToPointer(const ASR::CPtrToPointer_t &x) {} + + // void visit_BlockCall(const ASR::BlockCall_t &x) {} + + // void visit_Expr(const ASR::Expr_t &x) {} + + /********************************** Expr **********************************/ + // void visit_IfExp(const ASR::IfExp_t &x) {} + + void visit_ComplexConstructor(const ASR::ComplexConstructor_t &x) { + visit_expr(*x.m_re); + std::string re = s; + visit_expr(*x.m_im); + std::string im = s; + s = "(" + re + ", " + im + ")"; + } + + // void visit_NamedExpr(const ASR::NamedExpr_t &x) {} + + void visit_FunctionCall(const ASR::FunctionCall_t &x) { + std::string r = ""; + if (x.m_original_name) { + r += ASRUtils::symbol_name(x.m_original_name); + } else { + r += ASRUtils::symbol_name(x.m_name); + } + if (r == "bit_size") { + // TODO: Remove this once bit_size is implemented in IntrinsicScalarFunction + visit_expr(*x.m_value); + return; + } + + r += "("; + for (size_t i = 0; i < x.n_args; i ++) { + visit_expr(*x.m_args[i].m_value); + r += s; + if (i < x.n_args-1) r += ", "; + } + r += ")"; + s = r; + } + + void visit_IntrinsicScalarFunction(const ASR::IntrinsicScalarFunction_t &x) { + std::string out; + switch (x.m_intrinsic_id) { + SET_INTRINSIC_NAME(Abs, "abs"); + SET_INTRINSIC_NAME(Exp, "exp"); + SET_INTRINSIC_NAME(Max, "max"); + SET_INTRINSIC_NAME(Min, "min"); + SET_INTRINSIC_NAME(Sqrt, "sqrt"); + default : { + throw LCompilersException("IntrinsicScalarFunction: `" + + ASRUtils::get_intrinsic_name(x.m_intrinsic_id) + + "` is not implemented"); + } + } + LCOMPILERS_ASSERT(x.n_args == 1); + visit_expr(*x.m_args[0]); + out += "(" + s + ")"; + s = out; + } + + #define SET_ARR_INTRINSIC_NAME(X, func_name) \ + case (static_cast(ASRUtils::IntrinsicArrayFunctions::X)) : { \ + visit_expr(*x.m_args[0]); \ + out += func_name; break; \ + } + + void visit_IntrinsicArrayFunction(const ASR::IntrinsicArrayFunction_t &x) { + std::string out; + switch (x.m_arr_intrinsic_id) { + SET_ARR_INTRINSIC_NAME(Any, "any"); + SET_ARR_INTRINSIC_NAME(Sum, "sum"); + SET_ARR_INTRINSIC_NAME(Shape, "shape"); + default : { + throw LCompilersException("IntrinsicFunction: `" + + ASRUtils::get_array_intrinsic_name(x.m_arr_intrinsic_id) + + "` is not implemented"); + } + } + out += "(" + s + ")"; + s = out; + } + + // void visit_IntrinsicImpureFunction(const ASR::IntrinsicImpureFunction_t &x) {} + + void visit_StructTypeConstructor(const ASR::StructTypeConstructor_t &x) { + std::string r = indent; + r += ASRUtils::symbol_name(x.m_dt_sym); + r += "("; + for(size_t i = 0; i < x.n_args; i++) { + visit_expr(*x.m_args[i].m_value); + r += s; + if (i < x.n_args - 1) r += ", "; + } + r += ")"; + s = r; + } + + // void visit_EnumTypeConstructor(const ASR::EnumTypeConstructor_t &x) {} + + // void visit_UnionTypeConstructor(const ASR::UnionTypeConstructor_t &x) {} + + // void visit_ImpliedDoLoop(const ASR::ImpliedDoLoop_t &x) {} + + void visit_IntegerConstant(const ASR::IntegerConstant_t &x) { + s = std::to_string(x.m_n); + last_expr_precedence = Precedence::Ext; + } + + // void visit_IntegerBOZ(const ASR::IntegerBOZ_t &x) {} + + // void visit_IntegerBitNot(const ASR::IntegerBitNot_t &x) {} + + void visit_IntegerUnaryMinus(const ASR::IntegerUnaryMinus_t &x) { + visit_expr_with_precedence(*x.m_arg, 9); + s = "-" + s; + last_expr_precedence = Precedence::UnaryMinus; + } + + void visit_IntegerCompare(const ASR::IntegerCompare_t &x) { + handle_compare(x); + } + + void visit_IntegerBinOp(const ASR::IntegerBinOp_t &x) { + std::string r = "", m_op = binop2str(x.m_op); + int current_precedence = last_expr_precedence; + visit_expr_with_precedence(*x.m_left, current_precedence); + r += s; + r += m_op; + visit_expr_with_precedence(*x.m_right, current_precedence); + if ((x.m_op == ASR::binopType::Sub && last_expr_precedence <= 8) || + (x.m_op == ASR::binopType::Div && last_expr_precedence <= 10)) { + s = "(" + s + ")"; + } + r += s; + last_expr_precedence = current_precedence; + s = r; + } + + // void visit_UnsignedIntegerConstant(const ASR::UnsignedIntegerConstant_t &x) {} + + // void visit_UnsignedIntegerUnaryMinus(const ASR::UnsignedIntegerUnaryMinus_t &x) {} + + // void visit_UnsignedIntegerBitNot(const ASR::UnsignedIntegerBitNot_t &x) {} + + // void visit_UnsignedIntegerCompare(const ASR::UnsignedIntegerCompare_t &x) {} + + // void visit_UnsignedIntegerBinOp(const ASR::UnsignedIntegerBinOp_t &x) {} + + void visit_RealConstant(const ASR::RealConstant_t &x) { + int kind = ASRUtils::extract_kind_from_ttype_t(x.m_type); + if (kind >= 8) { + s = std::to_string(x.m_r) + "d0"; + } else { + s = std::to_string(x.m_r); + } + last_expr_precedence = Precedence::Ext; + } + + void visit_RealUnaryMinus(const ASR::RealUnaryMinus_t &x) { + visit_expr_with_precedence(*x.m_arg, 9); + s = "-" + s; + last_expr_precedence = Precedence::UnaryMinus; + } + + void visit_RealCompare(const ASR::RealCompare_t &x) { + handle_compare(x); + } + + void visit_RealBinOp(const ASR::RealBinOp_t &x) { + std::string r = "", m_op = binop2str(x.m_op); + int current_precedence = last_expr_precedence; + visit_expr_with_precedence(*x.m_left, current_precedence); + r += s; + r += m_op; + visit_expr_with_precedence(*x.m_right, current_precedence); + r += s; + last_expr_precedence = current_precedence; + s = r; + } + + // void visit_RealCopySign(const ASR::RealCopySign_t &x) {} + + void visit_ComplexConstant(const ASR::ComplexConstant_t &x) { + std::string re = std::to_string(x.m_re); + std::string im = std::to_string(x.m_im); + s = "(" + re + ", " + im + ")"; + } + + void visit_ComplexUnaryMinus(const ASR::ComplexUnaryMinus_t &x) { + visit_expr_with_precedence(*x.m_arg, 9); + s = "-" + s; + last_expr_precedence = Precedence::UnaryMinus; + } + + void visit_ComplexCompare(const ASR::ComplexCompare_t &x) { + handle_compare(x); + } + + void visit_ComplexBinOp(const ASR::ComplexBinOp_t &x) { + std::string r = "", m_op = binop2str(x.m_op); + int current_precedence = last_expr_precedence; + visit_expr_with_precedence(*x.m_left, current_precedence); + r += s; + r += m_op; + visit_expr_with_precedence(*x.m_right, current_precedence); + r += s; + last_expr_precedence = current_precedence; + s = r; + } + + void visit_LogicalConstant(const ASR::LogicalConstant_t &x) { + s = "."; + if (x.m_value) { + s += "true"; + } else { + s += "false"; + } + s += "."; + last_expr_precedence = Precedence::Ext; + } + + void visit_LogicalNot(const ASR::LogicalNot_t &x) { + visit_expr_with_precedence(*x.m_arg, 5); + s = ".not. " + s; + last_expr_precedence = Precedence::Not; + } + + void visit_LogicalCompare(const ASR::LogicalCompare_t &x) { + handle_compare(x); + } + + void visit_LogicalBinOp(const ASR::LogicalBinOp_t &x) { + std::string r = "", m_op = logicalbinop2str(x.m_op); + int current_precedence = last_expr_precedence; + visit_expr_with_precedence(*x.m_left, current_precedence); + r += s; + r += m_op; + visit_expr_with_precedence(*x.m_right, current_precedence); + r += s; + last_expr_precedence = current_precedence; + s = r; + } + + void visit_StringConstant(const ASR::StringConstant_t &x) { + s = "\""; + s.append(x.m_s); + s += "\""; + last_expr_precedence = Precedence::Ext; + } + + void visit_StringConcat(const ASR::StringConcat_t &x) { + this->visit_expr(*x.m_left); + std::string left = std::move(s); + this->visit_expr(*x.m_right); + std::string right = std::move(s); + s = left + "//" + right; + } + + void visit_StringRepeat(const ASR::StringRepeat_t &x) { + this->visit_expr(*x.m_left); + std::string str = s; + this->visit_expr(*x.m_right); + std::string n = s; + s = "repeat(" + str + ", " + n + ")"; + } + + void visit_StringLen(const ASR::StringLen_t &x) { + visit_expr(*x.m_arg); + s = "len(" + s + ")"; + } + + void visit_StringItem(const ASR::StringItem_t &x) { + std::string r = ""; + this->visit_expr(*x.m_arg); + r += s; + r += "("; + this->visit_expr(*x.m_idx); + r += s; + r += ":"; + r += s; + r += ")"; + s = r; + } + + // void visit_StringSection(const ASR::StringSection_t &x) {} + + void visit_StringCompare(const ASR::StringCompare_t &x) { + handle_compare(x); + } + + // void visit_StringOrd(const ASR::StringOrd_t &x) {} + + void visit_StringChr(const ASR::StringChr_t &x) { + visit_expr(*x.m_arg); + s = "char(" + s + ")"; + } + + void visit_StringFormat(const ASR::StringFormat_t &x) { + std::string r = ""; + if (format_string.size() > 0) { + visit_expr(*x.m_fmt); + format_string = s; + } + for (size_t i = 0; i < x.n_args; i++) { + visit_expr(*x.m_args[i]); + r += s; + if (i < x.n_args-1) r += ", "; + } + s = r; + } + + // void visit_CPtrCompare(const ASR::CPtrCompare_t &x) {} + + // void visit_SymbolicCompare(const ASR::SymbolicCompare_t &x) {} + + void visit_Var(const ASR::Var_t &x) { + s = ASRUtils::symbol_name(x.m_v); + last_expr_precedence = Precedence::Ext; + } + + // void visit_FunctionParam(const ASR::FunctionParam_t &x) {} + + void visit_ArrayConstant(const ASR::ArrayConstant_t &x) { + std::string r = "["; + for(size_t i = 0; i < x.n_args; i++) { + visit_expr(*x.m_args[i]); + r += s; + if (i < x.n_args-1) r += ", "; + } + r += "]"; + s = r; + last_expr_precedence = Precedence::Ext; + } + + void visit_ArrayItem(const ASR::ArrayItem_t &x) { + std::string r = ""; + visit_expr(*x.m_v); + r += s; + r += "("; + for(size_t i = 0; i < x.n_args; i++) { + if (x.m_args[i].m_right) { + visit_expr(*x.m_args[i].m_right); + r += s; + } + if (i < x.n_args-1) r += ", "; + } + r += ")"; + s = r; + last_expr_precedence = Precedence::Ext; + } + + void visit_ArraySection(const ASR::ArraySection_t &x) { + std::string r = ""; + visit_expr(*x.m_v); + r += s; + r += "("; + for (size_t i = 0; i < x.n_args; i++) { + if (i > 0) { + r += ", "; + } + std::string left, right, step; + if (x.m_args[i].m_left) { + visit_expr(*x.m_args[i].m_left); + left = std::move(s); + r += left + ":"; + } + if (x.m_args[i].m_right) { + visit_expr(*x.m_args[i].m_right); + right = std::move(s); + r += right; + } + if (x.m_args[i].m_step ) { + visit_expr(*x.m_args[i].m_step); + step = std::move(s); + if (step != "1") { + r += ":" + step; + } + } + } + r += ")"; + s = r; + last_expr_precedence = Precedence::Ext; + } + + void visit_ArraySize(const ASR::ArraySize_t &x) { + visit_expr(*x.m_v); + std::string r = "size(" + s; + if (x.m_dim) { + r += ", "; + visit_expr(*x.m_dim); + r += s; + } + r += ")"; + s = r; + } + + void visit_ArrayBound(const ASR::ArrayBound_t &x) { + std::string r = ""; + if (x.m_bound == ASR::arrayboundType::UBound) { + r += "ubound("; + } else if (x.m_bound == ASR::arrayboundType::LBound) { + r += "lbound("; + } + visit_expr(*x.m_v); + r += s; + r += ", "; + visit_expr(*x.m_dim); + r += s; + r += ")"; + s = r; + } + + void visit_ArrayTranspose(const ASR::ArrayTranspose_t &x) { + visit_expr(*x.m_matrix); + s = "transpose(" + s + ")"; + } + + void visit_ArrayPack(const ASR::ArrayPack_t &x) { + std::string r; + r += "pack"; + r += "("; + visit_expr(*x.m_array); + r += s; + r += ", "; + visit_expr(*x.m_mask); + r += s; + if (x.m_vector) { + r += ", "; + visit_expr(*x.m_vector); + r += s; + } + r += ")"; + s = r; + } + + void visit_ArrayReshape(const ASR::ArrayReshape_t &x) { + std::string r; + r += "reshape("; + visit_expr(*x.m_array); + r += s; + r += ", "; + visit_expr(*x.m_shape); + r += s; + r += ")"; + s = r; + } + + void visit_ArrayAll(const ASR::ArrayAll_t &x) { + std::string r; + r += "all"; + r += "("; + visit_expr(*x.m_mask); + r += s; + if (x.m_dim) { + visit_expr(*x.m_dim); + r += s; + } + r += ")"; + s = r; + } + + // void visit_BitCast(const ASR::BitCast_t &x) {} + + void visit_StructInstanceMember(const ASR::StructInstanceMember_t &x) { + std::string r; + visit_expr(*x.m_v); + r += s; + r += "%"; + r += ASRUtils::symbol_name(ASRUtils::symbol_get_past_external(x.m_m)); + s = r; + } + + // void visit_StructStaticMember(const ASR::StructStaticMember_t &x) {} + + // void visit_EnumStaticMember(const ASR::EnumStaticMember_t &x) {} + + // void visit_UnionInstanceMember(const ASR::UnionInstanceMember_t &x) {} + + // void visit_EnumName(const ASR::EnumName_t &x) {} + + // void visit_EnumValue(const ASR::EnumValue_t &x) {} + + // void visit_OverloadedCompare(const ASR::OverloadedCompare_t &x) {} + + // void visit_OverloadedBinOp(const ASR::OverloadedBinOp_t &x) {} + + // void visit_OverloadedUnaryMinus(const ASR::OverloadedUnaryMinus_t &x) {} + + void visit_Cast(const ASR::Cast_t &x) { + std::string r; + visit_expr(*x.m_arg); + switch (x.m_kind) { + case (ASR::cast_kindType::IntegerToReal) : { + int dest_kind = ASRUtils::extract_kind_from_ttype_t(x.m_type); + switch (dest_kind) { + case 1: r = "real(" + s + ", " + "kind=dest_kind" + ")"; break; + case 2: r = "real(" + s + ", " + "kind=dest_kind" + ")"; break; + case 4: r = "real(" + s + ", " + "kind=dest_kind" + ")"; break; + case 8: r = "real(" + s + ", " + "kind=dest_kind" + ")"; break; + default: throw CodeGenError("Cast IntegerToReal: Unsupported Kind " + std::to_string(dest_kind)); + } + last_expr_precedence = 2; + break; + } + case (ASR::cast_kindType::RealToInteger) : { + int dest_kind = ASRUtils::extract_kind_from_ttype_t(x.m_type); + switch (dest_kind) { + case 1: r = "int(" + s + ", " + "kind=dest_kind" + ")"; break; + case 2: r = "int(" + s + ", " + "kind=dest_kind" + ")"; break; + case 4: r = "int(" + s + ", " + "kind=dest_kind" + ")"; break; + case 8: r = "int(" + s + ", " + "kind=dest_kind" + ")"; break; + default: throw CodeGenError("Cast RealToInteger: Unsupported Kind " + std::to_string(dest_kind)); + } + last_expr_precedence = 2; + break; + } + case (ASR::cast_kindType::RealToReal) : { + int dest_kind = ASRUtils::extract_kind_from_ttype_t(x.m_type); + switch (dest_kind) { + case 1: r = "real(" + s + ", " + "kind=dest_kind" + ")"; break; + case 2: r = "real(" + s + ", " + "kind=dest_kind" + ")"; break; + case 4: r = "real(" + s + ", " + "kind=dest_kind" + ")"; break; + case 8: r = "real(" + s + ", " + "kind=dest_kind" + ")"; break; + default: throw CodeGenError("Cast RealToReal: Unsupported Kind " + std::to_string(dest_kind)); + } + last_expr_precedence = 2; + break; + } + case (ASR::cast_kindType::IntegerToInteger) : { + int dest_kind = ASRUtils::extract_kind_from_ttype_t(x.m_type); + switch (dest_kind) { + case 1: r = "int(" + s + ", " + "kind=dest_kind" + ")"; break; + case 2: r = "int(" + s + ", " + "kind=dest_kind" + ")"; break; + case 4: r = "int(" + s + ", " + "kind=dest_kind" + ")"; break; + case 8: r = "int(" + s + ", " + "kind=dest_kind" + ")"; break; + default: throw CodeGenError("Cast IntegerToInteger: Unsupported Kind " + std::to_string(dest_kind)); + } + last_expr_precedence = 2; + break; + } + case (ASR::cast_kindType::ComplexToComplex) : { + int dest_kind = ASRUtils::extract_kind_from_ttype_t(x.m_type); + switch (dest_kind) { + case 4: r = "cmplx(" + s + ", " + "kind=dest_kind" + ")"; break; + case 8: r = "cmplx(" + s + ", " + "kind=dest_kind" + ")"; break; + default: throw CodeGenError("Cast ComplexToComplex: Unsupported Kind " + std::to_string(dest_kind)); + } + last_expr_precedence = 2; + break; + } + case (ASR::cast_kindType::IntegerToComplex) : { + int dest_kind = ASRUtils::extract_kind_from_ttype_t(x.m_type); + switch (dest_kind) { + case 4: r = "cmplx(" + s + ", " + "0.0" + ", " + "kind=dest_kind" + ")"; break; + case 8: r = "cmplx(" + s + ", " + "0.0" + ", " + "kind=dest_kind" + ")"; break; + default: throw CodeGenError("Cast IntegerToComplex: Unsupported Kind " + std::to_string(dest_kind)); + } + last_expr_precedence = 2; + break; + } + case (ASR::cast_kindType::ComplexToReal) : { + int dest_kind = ASRUtils::extract_kind_from_ttype_t(x.m_type); + switch (dest_kind) { + case 4: r = "real(" + s + ", " + "kind=dest_kind" + ")"; break; + case 8: r = "real(" + s + ", " + "kind=dest_kind" + ")"; break; + default: throw CodeGenError("Cast ComplexToReal: Unsupported Kind " + std::to_string(dest_kind)); + } + last_expr_precedence = 2; + break; + } + case (ASR::cast_kindType::RealToComplex) : { + int dest_kind = ASRUtils::extract_kind_from_ttype_t(x.m_type); + switch (dest_kind) { + case 4: r = "cmplx(" + s + ", " + "0.0" + ", " + "kind=dest_kind" + ")"; break; + case 8: r = "cmplx(" + s + ", " + "0.0" + ", " + "kind=dest_kind" + ")"; break; + default: throw CodeGenError("Cast IntegerToComplex: Unsupported Kind " + std::to_string(dest_kind)); + } + last_expr_precedence = 2; + break; + } + case (ASR::cast_kindType::LogicalToInteger) : { + s = "int(" + s + ")"; + last_expr_precedence = 2; + break; + } + case (ASR::cast_kindType::LogicalToCharacter) : { + s = "char(" + s + ")"; + last_expr_precedence = 2; + break; + } + case (ASR::cast_kindType::IntegerToLogical) : { + // Implicit conversion between integer -> logical + break; + } + case (ASR::cast_kindType::LogicalToReal) : { + int dest_kind = ASRUtils::extract_kind_from_ttype_t(x.m_type); + switch (dest_kind) { + case 4: r = "real(" + s + ", " + "kind=dest_kind" + ")"; break; + case 8: r = "real(" + s + ", " + "kind=dest_kind" + ")"; break; + default: throw CodeGenError("Cast LogicalToReal: Unsupported Kind " + std::to_string(dest_kind)); + } + last_expr_precedence = 2; + break; + } + case (ASR::cast_kindType::RealToLogical) : { + s = "(bool)(" + s + ")"; + last_expr_precedence = 2; + break; + } + case (ASR::cast_kindType::CharacterToLogical) : { + s = "(bool)(len(" + s + ") > 0)"; + last_expr_precedence = 2; + break; + } + case (ASR::cast_kindType::ComplexToLogical) : { + s = "(bool)(" + s + ")"; + last_expr_precedence = 2; + break; + } + case (ASR::cast_kindType::IntegerToCharacter) : { + int dest_kind = ASRUtils::extract_kind_from_ttype_t(x.m_type); + switch (dest_kind) { + case 1: s = "char(" + s + ", " + "kind=dest_kind" + ")"; break; + case 2: s = "char(" + s + ", " + "kind=dest_kind" + ")"; break; + case 4: s = "char(" + s + ", " + "kind=dest_kind" + ")"; break; + case 8: s = "char(" + s + ", " + "kind=dest_kind" + ")"; break; + default: throw CodeGenError("Cast IntegerToCharacter: Unsupported Kind " + \ + std::to_string(dest_kind)); + } + last_expr_precedence = 2; + break; + } + case (ASR::cast_kindType::CharacterToInteger) : { + int dest_kind = ASRUtils::extract_kind_from_ttype_t(x.m_type); + switch (dest_kind) { + case 1: s = "ichar(" + s + ", " + "kind=dest_kind" + ")"; break; + case 2: s = "ichar(" + s + ", " + "kind=dest_kind" + ")"; break; + case 4: s = "ichar(" + s + ", " + "kind=dest_kind" + ")"; break; + case 8: s = "ichar(" + s + ", " + "kind=dest_kind" + ")"; break; + default: throw CodeGenError("Cast CharacterToInteger: Unsupported Kind " + \ + std::to_string(dest_kind)); + } + last_expr_precedence = 2; + break; + } + default : { + throw CodeGenError("Cast kind " + std::to_string(x.m_kind) + " not implemented", + x.base.base.loc); + } + } + } + + void visit_ArrayBroadcast(const ASR::ArrayBroadcast_t &x) { + // TODO + visit_expr(*x.m_array); + } + + void visit_ArrayPhysicalCast(const ASR::ArrayPhysicalCast_t &x) { + this->visit_expr(*x.m_arg); + } + + void visit_ComplexRe(const ASR::ComplexRe_t &x) { + visit_expr(*x.m_arg); + s = "real(" + s + ")"; + } + + void visit_ComplexIm(const ASR::ComplexIm_t &x) { + visit_expr(*x.m_arg); + s = "aimag(" + s + ")"; + } + + // void visit_CLoc(const ASR::CLoc_t &x) {} + + // void visit_PointerToCPtr(const ASR::PointerToCPtr_t &x) {} + + // void visit_GetPointer(const ASR::GetPointer_t &x) {} + + void visit_IntegerBitLen(const ASR::IntegerBitLen_t &x) { + visit_expr(*x.m_a); + s = "bit_size(" + s + ")"; + } + + void visit_Ichar(const ASR::Ichar_t &x) { + visit_expr(*x.m_arg); + s = "ichar(" + s + ")"; + } + + void visit_Iachar(const ASR::Iachar_t &x) { + visit_expr(*x.m_arg); + s = "iachar(" + s + ")"; + } + + // void visit_SizeOfType(const ASR::SizeOfType_t &x) {} + + // void visit_PointerNullConstant(const ASR::PointerNullConstant_t &x) {} + + // void visit_PointerAssociated(const ASR::PointerAssociated_t &x) {} + + void visit_IntrinsicFunctionSqrt(const ASR::IntrinsicFunctionSqrt_t &x) { + visit_expr(*x.m_arg); + s = "sqrt(" + s + ")"; + } + + /******************************* Case Stmt ********************************/ + void visit_CaseStmt(const ASR::CaseStmt_t &x) { + std::string r = indent; + r += "case ("; + for(size_t i = 0; i < x.n_test; i ++) { + visit_expr(*x.m_test[i]); + r += s; + if (i < x.n_test-1) r += ", "; + } + r += ")\n"; + inc_indent(); + for(size_t i = 0; i < x.n_body; i ++) { + visit_stmt(*x.m_body[i]); + r += s; + } + dec_indent(); + s = r; + } + + void visit_CaseStmt_Range(const ASR::CaseStmt_Range_t &x) { + std::string r = indent; + r += "case ("; + if (x.m_start) { + visit_expr(*x.m_start); + r += s; + } + r += ":"; + if (x.m_end) { + visit_expr(*x.m_end); + r += s; + } + r += ")\n"; + inc_indent(); + for(size_t i = 0; i < x.n_body; i ++) { + visit_stmt(*x.m_body[i]); + r += s; + } + dec_indent(); + s = r; + } + +}; + +Result asr_to_fortran(ASR::TranslationUnit_t &asr, + diag::Diagnostics &diagnostics, bool color, int indent) { + ASRToFortranVisitor v(color, indent); + try { + v.visit_TranslationUnit(asr); + } catch (const CodeGenError &e) { + diagnostics.diagnostics.push_back(e.d); + return Error(); + } + return v.s; +} + +} // namespace LCompilers diff --git a/src/libasr/codegen/asr_to_fortran.h b/src/libasr/codegen/asr_to_fortran.h new file mode 100644 index 00000000000..906ac7a6716 --- /dev/null +++ b/src/libasr/codegen/asr_to_fortran.h @@ -0,0 +1,15 @@ +#ifndef LFORTRAN_ASR_TO_FORTRAN_H +#define LFORTRAN_ASR_TO_FORTRAN_H + +#include +#include + +namespace LCompilers { + + // Converts ASR to Fortran source code + Result asr_to_fortran(ASR::TranslationUnit_t &asr, + diag::Diagnostics &diagnostics, bool color, int indent); + +} // namespace LCompilers + +#endif // LFORTRAN_ASR_TO_FORTRAN_H diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index ae4bac53ab1..bccc27863a7 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -87,8 +87,10 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor //! To be used by visit_StructInstanceMember. std::string current_der_type_name; - //! Helpful for debugging while testing LLVM code - void print_util(llvm::Value* v, std::string fmt_chars, std::string endline="\t") { + //! Helpful for debugging while testing LLVM code + void print_util(llvm::Value* v, std::string fmt_chars, std::string endline="\t") { + // Usage: + // print_util(tmp, "%d") // `tmp` to be an integer type std::vector args; std::vector fmt; args.push_back(v); @@ -106,6 +108,26 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor printf(context, *module, *builder, printf_args); } + //! Helpful for debugging while testing LLVM code + void print_util(llvm::Value* v, std::string endline="\n") { + // Usage: + // print_util(tmp) + std::string buf; + llvm::raw_string_ostream os(buf); + v->print(os); + std::cout << os.str() << endline; + } + + //! Helpful for debugging while testing LLVM code + void print_util(llvm::Type* v, std::string endline="\n") { + // Usage: + // print_util(tmp->getType()) + std::string buf; + llvm::raw_string_ostream os(buf); + v->print(os); + std::cout << os.str() << endline; + } + public: diag::Diagnostics &diag; llvm::LLVMContext &context; @@ -178,6 +200,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor std::unique_ptr set_api_lp; std::unique_ptr set_api_sc; std::unique_ptr arr_descr; + std::vector heap_arrays; + std::map strings_to_be_allocated; // (array, size) + Vec strings_to_be_deallocated; ASRToLLVMVisitor(Allocator &al, llvm::LLVMContext &context, std::string infile, CompilerOptions &compiler_options_, diag::Diagnostics &diagnostics) : @@ -205,7 +230,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor set_api_sc(std::make_unique(context, llvm_utils.get(), builder.get())), arr_descr(LLVMArrUtils::Descriptor::get_descriptor(context, builder.get(), llvm_utils.get(), - LLVMArrUtils::DESCR_TYPE::_SimpleCMODescriptor)) + LLVMArrUtils::DESCR_TYPE::_SimpleCMODescriptor, compiler_options_, heap_arrays)) { llvm_utils->tuple_api = tuple_api.get(); llvm_utils->list_api = list_api.get(); @@ -216,6 +241,15 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor llvm_utils->dict_api_sc = dict_api_sc.get(); llvm_utils->set_api_lp = set_api_lp.get(); llvm_utils->set_api_sc = set_api_sc.get(); + strings_to_be_deallocated.reserve(al, 1); + } + + llvm::AllocaInst* CreateAlloca(llvm::Type* type, + llvm::Value* size=nullptr, const std::string& Name="") { + llvm::BasicBlock &entry_block = builder->GetInsertBlock()->getParent()->getEntryBlock(); + llvm::IRBuilder<> builder0(context); + builder0.SetInsertPoint(&entry_block, entry_block.getFirstInsertionPt()); + return builder0.CreateAlloca(type, size, Name); } llvm::Value* CreateLoad(llvm::Value *x) { @@ -245,7 +279,11 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor // to our new block builder->CreateBr(bb); } +#if LLVM_VERSION_MAJOR >= 16 + fn->insert(fn->end(), bb); +#else fn->getBasicBlockList().push_back(bb); +#endif builder->SetInsertPoint(bb); } @@ -285,6 +323,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } // end + loop_head.pop_back(); + loop_head_names.pop_back(); loop_or_block_end.pop_back(); loop_or_block_end_names.pop_back(); start_new_block(loopend); @@ -400,6 +440,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor ASR::dimension_t* m_dims, int n_dims, bool is_data_only=false, bool reserve_data_memory=true) { std::vector> llvm_dims; + int64_t ptr_loads_copy = ptr_loads; + ptr_loads = 2; for( int r = 0; r < n_dims; r++ ) { ASR::dimension_t m_dim = m_dims[r]; visit_expr(*(m_dim.m_start)); @@ -408,6 +450,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor llvm::Value* end = tmp; llvm_dims.push_back(std::make_pair(start, end)); } + ptr_loads = ptr_loads_copy; if( is_data_only ) { if( !ASRUtils::is_fixed_size_array(m_dims, n_dims) ) { llvm::Value* const_1 = llvm::ConstantInt::get(context, llvm::APInt(32, 1)); @@ -416,11 +459,25 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor llvm::Value* dim_size = llvm_dims[r].second; prod = builder->CreateMul(prod, dim_size); } - llvm::Value* arr_first = builder->CreateAlloca(llvm_data_type, prod); - builder->CreateStore(arr_first, arr); + llvm::Value* arr_first = nullptr; + if( !compiler_options.stack_arrays ) { + llvm::DataLayout data_layout(module.get()); + uint64_t size = data_layout.getTypeAllocSize(llvm_data_type); + prod = builder->CreateMul(prod, + llvm::ConstantInt::get(context, llvm::APInt(32, size))); + llvm::Value* arr_first_i8 = LLVMArrUtils::lfortran_malloc( + context, *module, *builder, prod); + heap_arrays.push_back(arr_first_i8); + arr_first = builder->CreateBitCast( + arr_first_i8, llvm_data_type->getPointerTo()); + } else { + arr_first = builder->CreateAlloca(llvm_data_type, prod); + builder->CreateStore(arr_first, arr); + } } } else { - arr_descr->fill_array_details(arr, llvm_data_type, n_dims, llvm_dims, reserve_data_memory); + arr_descr->fill_array_details(arr, llvm_data_type, n_dims, + llvm_dims, module.get(), reserve_data_memory); } } @@ -529,6 +586,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor nullptr); std::vector args = {pleft_arg, pright_arg, presult}; builder->CreateCall(fn, args); + strings_to_be_deallocated.push_back(al, CreateLoad(presult)); return CreateLoad(presult); } @@ -1024,9 +1082,14 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor return free_fn; } - inline void call_lfortran_free_string(llvm::Function* fn) { - std::vector args = {tmp}; - builder->CreateCall(fn, args); + inline void call_lcompilers_free_strings() { + // if (strings_to_be_deallocated.n > 0) { + // llvm::Function* free_fn = _Deallocate(); + // for( auto &value: strings_to_be_deallocated ) { + // builder->CreateCall(free_fn, {value}); + // } + // strings_to_be_deallocated.reserve(al, 1); + // } } llvm::Function* _Allocate(bool realloc_lhs) { @@ -1393,6 +1456,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor return; } this->visit_expr(*x.m_arg); + if (tmp->getType()->isPointerTy()) { + tmp = CreateLoad(tmp); + } llvm::Value *c = tmp; int64_t kind_value = ASRUtils::extract_kind_from_ttype_t(ASRUtils::expr_type(x.m_arg)); std::string func_name; @@ -2216,6 +2282,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor p = CreateGEP(str, idx_vec); } else { p = lfortran_str_item(str, idx); + strings_to_be_deallocated.push_back(al, p); } // TODO: Currently the string starts at the right location, but goes to the end of the original string. // We have to allocate a new string, copy it and add null termination. @@ -2491,8 +2558,13 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } this->visit_expr(*x.m_v); ptr_loads = ptr_loads_copy; - if( ASR::is_a(*ASRUtils::type_get_past_pointer(x_m_v_type)) ) { - tmp = CreateLoad(llvm_utils->create_gep(tmp, 1)); + if( ASR::is_a(*ASRUtils::type_get_past_pointer( + ASRUtils::type_get_past_allocatable(x_m_v_type))) ) { + if (ASRUtils::is_allocatable(x_m_v_type)) { + tmp = llvm_utils->create_gep(CreateLoad(tmp), 1); + } else { + tmp = CreateLoad(llvm_utils->create_gep(tmp, 1)); + } if( current_select_type_block_type ) { tmp = builder->CreateBitCast(tmp, current_select_type_block_type); current_der_type_name = current_select_type_block_der_type; @@ -2512,29 +2584,25 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor current_der_type_name = dertype2parent[current_der_type_name]; } int member_idx = name2memidx[current_der_type_name][member_name]; - std::vector idx_vec = { - llvm::ConstantInt::get(context, llvm::APInt(32, 0)), - llvm::ConstantInt::get(context, llvm::APInt(32, member_idx))}; - // if( (ASR::is_a(*x.m_v) || - // ASR::is_a(*x.m_v)) && - // is_nested_pointer(tmp) ) { - // tmp = CreateLoad(tmp); - // } - llvm::Value* tmp1 = CreateGEP(tmp, idx_vec); - ASR::ttype_t* member_type = member->m_type; - if( ASR::is_a(*member_type) ) { - member_type = ASR::down_cast(member_type)->m_type; - } - if( member_type->type == ASR::ttypeType::Struct ) { - ASR::Struct_t* der = (ASR::Struct_t*)(&(member_type->base)); - ASR::StructType_t* der_type = (ASR::StructType_t*)(&(der->m_derived_type->base)); - current_der_type_name = std::string(der_type->m_name); + + tmp = llvm_utils->create_gep(tmp, member_idx); + ASR::ttype_t* member_type = ASRUtils::type_get_past_pointer( + ASRUtils::type_get_past_allocatable(member->m_type)); + if( ASR::is_a(*member_type) ) { + ASR::symbol_t *s_sym = ASR::down_cast( + member_type)->m_derived_type; + current_der_type_name = ASRUtils::symbol_name( + ASRUtils::symbol_get_past_external(s_sym)); uint32_t h = get_hash((ASR::asr_t*)member); if( llvm_symtab.find(h) != llvm_symtab.end() ) { tmp = llvm_symtab[h]; } + } else if ( ASR::is_a(*member_type) ) { + ASR::symbol_t *s_sym = ASR::down_cast( + member_type)->m_class_type; + current_der_type_name = ASRUtils::symbol_name( + ASRUtils::symbol_get_past_external(s_sym)); } - tmp = tmp1; } void visit_Variable(const ASR::Variable_t &x) { @@ -2625,6 +2693,11 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor module->getNamedGlobal(x.m_name)->setInitializer( llvm::Constant::getNullValue(character_type) ); + ASR::Character_t *t = down_cast(x.m_type); + if( t->m_len >= 0 ) { + strings_to_be_allocated.insert(std::pair(ptr, llvm::ConstantInt::get( + context, llvm::APInt(32, t->m_len+1)))); + } } } llvm_symtab[h] = ptr; @@ -2902,6 +2975,12 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } void visit_Program(const ASR::Program_t &x) { + loop_head.clear(); + loop_head_names.clear(); + loop_or_block_end.clear(); + loop_or_block_end_names.clear(); + heap_arrays.clear(); + strings_to_be_deallocated.reserve(al, 1); SymbolTable* current_scope_copy = current_scope; current_scope = x.m_symtab; bool is_dict_present_copy_lp = dict_api_lp->is_dict_present(); @@ -2963,9 +3042,20 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } declare_vars(x); + for(auto &value: strings_to_be_allocated) { + llvm::Value *init_value = LLVM::lfortran_malloc(context, *module, + *builder, value.second); + string_init(context, *module, *builder, value.second, init_value); + builder->CreateStore(init_value, value.first); + } for (size_t i=0; ivisit_stmt(*x.m_body[i]); } + for( auto& value: heap_arrays ) { + LLVM::lfortran_free(context, *module, *builder, value); + } + call_lcompilers_free_strings(); + llvm::Value *ret_val2 = llvm::ConstantInt::get(context, llvm::APInt(32, 0)); builder->CreateRet(ret_val2); @@ -2977,6 +3067,12 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor // Finalize the debug info. if (compiler_options.emit_debug_info) DBuilder->finalize(); current_scope = current_scope_copy; + loop_head.clear(); + loop_head_names.clear(); + loop_or_block_end.clear(); + loop_or_block_end_names.clear(); + heap_arrays.clear(); + strings_to_be_deallocated.reserve(al, 1); } /* @@ -3005,18 +3101,15 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor ASRUtils::type_get_past_allocatable(m_type))); llvm::Type* llvm_data_type = llvm_utils->get_type_from_ttype_t_util(asr_data_type, module.get()); llvm::Value* ptr_ = nullptr; - if( is_malloc_array_type && m_type->type != ASR::ttypeType::Pointer && - !is_list && !is_data_only ) { + if( is_malloc_array_type && !is_list && !is_data_only ) { ptr_ = builder->CreateAlloca(type_, nullptr, "arr_desc"); arr_descr->fill_dimension_descriptor(ptr_, n_dims); } if( is_array_type && !is_malloc_array_type && - m_type->type != ASR::ttypeType::Pointer && !is_list ) { fill_array_details(ptr, llvm_data_type, m_dims, n_dims, is_data_only); } if( is_array_type && is_malloc_array_type && - m_type->type != ASR::ttypeType::Pointer && !is_list && !is_data_only ) { // Set allocatable arrays as unallocated LCOMPILERS_ASSERT(ptr_ != nullptr); @@ -3086,7 +3179,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor is_malloc_array_type, is_list, m_dims, n_dims, a_kind, module.get()); llvm::Type* type_ = llvm_utils->get_type_from_ttype_t_util( - ASRUtils::type_get_past_allocatable(v->m_type), module.get(), v->m_abi); + ASRUtils::type_get_past_pointer( + ASRUtils::type_get_past_allocatable(v->m_type)), module.get(), v->m_abi); fill_array_details_(ptr_member, type_, m_dims, n_dims, is_malloc_array_type, is_array_type, is_list, v->m_type); } else { @@ -3100,7 +3194,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor void allocate_array_members_of_struct_arrays(llvm::Value* ptr, ASR::ttype_t* v_m_type) { ASR::array_physical_typeType phy_type = ASRUtils::extract_physical_type(v_m_type); - llvm::Value* array_size = builder->CreateAlloca( + llvm::Value* array_size = CreateAlloca( llvm::Type::getInt32Ty(context), nullptr, "array_size"); switch( phy_type ) { case ASR::array_physical_typeType::FixedSizeArray: { @@ -3119,7 +3213,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor LCOMPILERS_ASSERT(false); } } - llvm::Value* llvmi = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr, "i"); + llvm::Value* llvmi = CreateAlloca(llvm::Type::getInt32Ty(context), nullptr, "i"); LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0)), llvmi); create_loop(nullptr, [=]() { @@ -3269,7 +3363,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor v->m_type, v->m_type_declaration, v->m_storage, is_array_type, is_malloc_array_type, is_list, m_dims, n_dims, a_kind, module.get()); llvm::Type* type_ = llvm_utils->get_type_from_ttype_t_util( - ASRUtils::type_get_past_allocatable(v->m_type), module.get(), v->m_abi); + ASRUtils::type_get_past_pointer( + ASRUtils::type_get_past_allocatable(v->m_type)), module.get(), v->m_abi); /* * The following if block is used for converting any @@ -3277,16 +3372,11 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor * can be passed as an argument in a function call in LLVM IR. */ if( x.class_type == ASR::symbolType::Function) { - std::uint32_t m_h; std::string m_name = std::string(x.m_name); - ASR::abiType abi_type = ASR::abiType::Source; - bool is_v_arg = false; - if( x.class_type == ASR::symbolType::Function ) { - ASR::Function_t* _func = (ASR::Function_t*)(&(x.base)); - m_h = get_hash((ASR::asr_t*)_func); - abi_type = ASRUtils::get_FunctionType(_func)->m_abi; - is_v_arg = is_argument(v, _func->m_args, _func->n_args); - } + ASR::Function_t* _func = (ASR::Function_t*)(&(x.base)); + std::uint32_t m_h = get_hash((ASR::asr_t*)_func); + ASR::abiType abi_type = ASRUtils::get_FunctionType(_func)->m_abi; + bool is_v_arg = is_argument(v, _func->m_args, _func->n_args); if( is_array_type && !is_list ) { /* The first element in an array descriptor can be either of * llvm::ArrayType or llvm::PointerType. However, a @@ -3326,7 +3416,29 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } ptr_loads = ptr_loads_copy; } - llvm::AllocaInst *ptr = builder->CreateAlloca(type, array_size, v->m_name); + llvm::Value *ptr = nullptr; + if( !compiler_options.stack_arrays && array_size ) { + llvm::DataLayout data_layout(module.get()); + uint64_t size = data_layout.getTypeAllocSize(type); + array_size = builder->CreateMul(array_size, + llvm::ConstantInt::get(context, llvm::APInt(32, size))); + llvm::Value* ptr_i8 = LLVMArrUtils::lfortran_malloc( + context, *module, *builder, array_size); + heap_arrays.push_back(ptr_i8); + ptr = builder->CreateBitCast(ptr_i8, type->getPointerTo()); + } else { + if (v->m_storage == ASR::storage_typeType::Save) { + std::string parent_function_name = std::string(x.m_name); + std::string global_name = parent_function_name+ "." + v->m_name; + ptr = module->getOrInsertGlobal(global_name, type); + llvm::GlobalVariable *gptr = module->getNamedGlobal(global_name); + gptr->setLinkage(llvm::GlobalValue::InternalLinkage); + llvm::Constant *init_value = llvm::Constant::getNullValue(type); + gptr->setInitializer(init_value); + } else { + ptr = builder->CreateAlloca(type, array_size, v->m_name); + } + } set_pointer_variable_to_null(llvm::ConstantPointerNull::get( static_cast(type)), ptr) if( ASR::is_a( @@ -3366,7 +3478,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor int64_t alignment_value = -1; if( ASRUtils::extract_value(struct_type->m_alignment, alignment_value) ) { llvm::Align align(alignment_value); - ptr->setAlignment(align); + reinterpret_cast(ptr)->setAlignment(align); } } @@ -3427,6 +3539,17 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } } else if (ASR::is_a(*v->m_symbolic_value)) { builder->CreateStore(LLVM::CreateLoad(*builder, init_value), target_var); + } else if (is_a(*v->m_type) && !is_array_type) { + ASR::Character_t *t = down_cast(v->m_type); + llvm::Value *arg_size = llvm::ConstantInt::get(context, + llvm::APInt(32, t->m_len+1)); + llvm::Value *s_malloc = LLVM::lfortran_malloc(context, *module, *builder, arg_size); + string_init(context, *module, *builder, arg_size, s_malloc); + builder->CreateStore(s_malloc, target_var); + tmp = lfortran_str_copy(target_var, init_value); + if (v->m_intent == intent_local) { + strings_to_be_deallocated.push_back(al, CreateLoad(target_var)); + } } else { builder->CreateStore(init_value, target_var); } @@ -3435,24 +3558,28 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor ASR::Character_t *t = down_cast(v->m_type); target_var = ptr; int strlen = t->m_len; - if (strlen >= 0) { - // Compile time length - std::string empty(strlen, ' '); - llvm::Value *init_value = builder->CreateGlobalStringPtr(s2c(al, empty)); + if (strlen >= 0 || strlen == -3) { + llvm::Value *arg_size; + if (strlen == -3) { + LCOMPILERS_ASSERT(t->m_len_expr) + this->visit_expr(*t->m_len_expr); + arg_size = builder->CreateAdd(tmp, + llvm::ConstantInt::get(context, llvm::APInt(32, 1))); + } else { + // Compile time length + arg_size = llvm::ConstantInt::get(context, + llvm::APInt(32, strlen+1)); + } + llvm::Value *init_value = LLVM::lfortran_malloc(context, *module, *builder, arg_size); + string_init(context, *module, *builder, arg_size, init_value); builder->CreateStore(init_value, target_var); + if (v->m_intent == intent_local) { + strings_to_be_deallocated.push_back(al, CreateLoad(target_var)); + } } else if (strlen == -2) { // Allocatable string. Initialize to `nullptr` (unallocated) llvm::Value *init_value = llvm::Constant::getNullValue(type); builder->CreateStore(init_value, target_var); - } else if (strlen == -3) { - LCOMPILERS_ASSERT(t->m_len_expr) - this->visit_expr(*t->m_len_expr); - llvm::Value *arg_size = tmp; - arg_size = builder->CreateAdd(arg_size, llvm::ConstantInt::get(context, llvm::APInt(32, 1))); - // TODO: this temporary string is never deallocated (leaks memory) - llvm::Value *init_value = LLVM::lfortran_malloc(context, *module, *builder, arg_size); - string_init(context, *module, *builder, arg_size, init_value); - builder->CreateStore(init_value, target_var); } else { throw CodeGenError("Unsupported len value in ASR"); } @@ -3524,6 +3651,12 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } void visit_Function(const ASR::Function_t &x) { + loop_head.clear(); + loop_head_names.clear(); + loop_or_block_end.clear(); + loop_or_block_end_names.clear(); + heap_arrays.clear(); + strings_to_be_deallocated.reserve(al, 1); SymbolTable* current_scope_copy = current_scope; current_scope = x.m_symtab; bool is_dict_present_copy_lp = dict_api_lp->is_dict_present(); @@ -3552,6 +3685,12 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor // Finalize the debug info. if (compiler_options.emit_debug_info) DBuilder->finalize(); current_scope = current_scope_copy; + loop_head.clear(); + loop_head_names.clear(); + loop_or_block_end.clear(); + loop_or_block_end_names.clear(); + heap_arrays.clear(); + strings_to_be_deallocated.reserve(al, 1); } void instantiate_function(const ASR::Function_t &x){ @@ -3699,9 +3838,17 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor ret_val2 = tmp; } } + for( auto& value: heap_arrays ) { + LLVM::lfortran_free(context, *module, *builder, value); + } + call_lcompilers_free_strings(); builder->CreateRet(ret_val2); } else { start_new_block(proc_return); + for( auto& value: heap_arrays ) { + LLVM::lfortran_free(context, *module, *builder, value); + } + call_lcompilers_free_strings(); builder->CreateRetVoid(); } } @@ -3982,10 +4129,13 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor builder0.SetInsertPoint(&entry_block, entry_block.getFirstInsertionPt()); llvm::AllocaInst *res = builder0.CreateAlloca( llvm::Type::getInt1Ty(context), nullptr, "is_associated"); - ASR::Variable_t *p = EXPR2VAR(x.m_ptr); - uint32_t value_h = get_hash((ASR::asr_t*)p); - llvm::Value *ptr = llvm_symtab[value_h], *nptr; - ptr = CreateLoad(ptr); + ASR::ttype_t* p_type = ASRUtils::expr_type(x.m_ptr); + llvm::Value *ptr, *nptr; + int64_t ptr_loads_copy = ptr_loads; + ptr_loads = 1; + visit_expr_wrapper(x.m_ptr, true); + ptr = tmp; + ptr_loads = ptr_loads_copy; if( ASR::is_a(*ASRUtils::expr_type(x.m_ptr)) && x.m_tgt && ASR::is_a(*ASRUtils::expr_type(x.m_tgt)) ) { int64_t ptr_loads_copy = ptr_loads; @@ -4029,7 +4179,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor ptr = builder->CreatePtrToInt(ptr, llvm_utils->getIntType(8, false)); builder->CreateStore(builder->CreateICmpEQ(ptr, nptr), res); } else { - llvm::Type* value_type = llvm_utils->get_type_from_ttype_t_util(p->m_type, module.get()); + llvm::Type* value_type = llvm_utils->get_type_from_ttype_t_util(p_type, module.get()); nptr = llvm::ConstantPointerNull::get(static_cast(value_type)); nptr = builder->CreatePtrToInt(nptr, llvm_utils->getIntType(8, false)); ptr = builder->CreatePtrToInt(ptr, llvm_utils->getIntType(8, false)); @@ -4304,6 +4454,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor tmp = CreateLoad(tmp); } builder->CreateStore(tmp, str); + strings_to_be_deallocated.push_back(al, tmp); } void visit_Assignment(const ASR::Assignment_t &x) { @@ -4327,6 +4478,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor bool is_value_struct = ASR::is_a(*asr_value_type); if (ASR::is_a(*x.m_target)) { handle_StringSection_Assignment(x.m_target, x.m_value); + if (tmp == strings_to_be_deallocated.back()) { + strings_to_be_deallocated.erase(strings_to_be_deallocated.back()); + } return; } if( is_target_list && is_value_list ) { @@ -4572,7 +4726,15 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor if (lhs_is_string_arrayref && value->getType()->isPointerTy()) { value = CreateLoad(value); } - if (ASR::is_a(*x.m_target)) { + if ( (ASR::is_a(*x.m_value) || + ASR::is_a(*x.m_value) || + (ASR::is_a(*x.m_target) + && ASRUtils::is_character(*target_type))) && + !ASR::is_a(*x.m_target) ) { + builder->CreateStore(value, target); + strings_to_be_deallocated.erase(strings_to_be_deallocated.back()); + return; + } else if (ASR::is_a(*x.m_target)) { ASR::Variable_t *asr_target = EXPR2VAR(x.m_target); tmp = lfortran_str_copy(target, value, ASR::is_a(*asr_target->m_type)); @@ -4853,7 +5015,11 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor start_new_block(blockstart); llvm::BasicBlock *blockend = llvm::BasicBlock::Create(context, blockend_name); llvm::Function *fn = blockstart->getParent(); +#if LLVM_VERSION_MAJOR >= 16 + fn->insert(fn->end(), blockend); +#else fn->getBasicBlockList().push_back(blockend); +#endif builder->SetInsertPoint(blockstart); loop_or_block_end.push_back(blockend); loop_or_block_end_names.push_back(blockend_name); @@ -5354,16 +5520,24 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } void visit_If(const ASR::If_t &x) { + llvm::Value **strings_to_be_deallocated_copy = strings_to_be_deallocated.p; + size_t n = strings_to_be_deallocated.n; + strings_to_be_deallocated.reserve(al, 1); this->visit_expr_wrapper(x.m_test, true); - llvm_utils->create_if_else(tmp, [=]() { + llvm_utils->create_if_else(tmp, [&]() { for (size_t i=0; ivisit_stmt(*x.m_body[i]); } - }, [=]() { + call_lcompilers_free_strings(); + }, [&]() { for (size_t i=0; ivisit_stmt(*x.m_orelse[i]); } + call_lcompilers_free_strings(); }); + strings_to_be_deallocated.reserve(al, n); + strings_to_be_deallocated.n = n; + strings_to_be_deallocated.p = strings_to_be_deallocated_copy; } void visit_IfExp(const ASR::IfExp_t &x) { @@ -5371,7 +5545,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor this->visit_expr_wrapper(x.m_test, true); llvm::Value *cond = tmp; llvm::Type* _type = llvm_utils->get_type_from_ttype_t_util(x.m_type, module.get()); - llvm::Value* ifexp_res = builder->CreateAlloca(_type); + llvm::Value* ifexp_res = CreateAlloca(_type); llvm_utils->create_if_else(cond, [&]() { this->visit_expr_wrapper(x.m_body, true); builder->CreateStore(tmp, ifexp_res); @@ -5387,14 +5561,22 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor //} void visit_WhileLoop(const ASR::WhileLoop_t &x) { + llvm::Value **strings_to_be_deallocated_copy = strings_to_be_deallocated.p; + size_t n = strings_to_be_deallocated.n; + strings_to_be_deallocated.reserve(al, 1); create_loop(x.m_name, [=]() { this->visit_expr_wrapper(x.m_test, true); + call_lcompilers_free_strings(); return tmp; - }, [=]() { + }, [&]() { for (size_t i=0; ivisit_stmt(*x.m_body[i]); } + call_lcompilers_free_strings(); }); + strings_to_be_deallocated.reserve(al, n); + strings_to_be_deallocated.n = n; + strings_to_be_deallocated.p = strings_to_be_deallocated_copy; } bool case_insensitive_string_compare(const std::string& str1, const std::string& str2) { @@ -5493,14 +5675,12 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor this->visit_expr_wrapper(x.m_right, true); llvm::Value *right_val = tmp; llvm::Value *zero, *cond; - llvm::AllocaInst *result; if (ASRUtils::is_integer(*x.m_type)) { int a_kind = down_cast(x.m_type)->m_kind; int init_value_bits = 8*a_kind; zero = llvm::ConstantInt::get(context, llvm::APInt(init_value_bits, 0)); cond = builder->CreateICmpEQ(left_val, zero); - result = builder->CreateAlloca(llvm_utils->getIntType(a_kind), nullptr); } else if (ASRUtils::is_real(*x.m_type)) { int a_kind = down_cast(x.m_type)->m_kind; int init_value_bits = 8*a_kind; @@ -5511,39 +5691,25 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor zero = llvm::ConstantFP::get(context, llvm::APFloat((double)0)); } - result = builder->CreateAlloca(llvm_utils->getFPType(a_kind), nullptr); cond = builder->CreateFCmpUEQ(left_val, zero); } else if (ASRUtils::is_character(*x.m_type)) { zero = llvm::Constant::getNullValue(character_type); cond = lfortran_str_cmp(left_val, zero, "_lpython_str_compare_eq"); - result = builder->CreateAlloca(character_type, nullptr); } else if (ASRUtils::is_logical(*x.m_type)) { zero = llvm::ConstantInt::get(context, llvm::APInt(1, 0)); cond = builder->CreateICmpEQ(left_val, zero); - result = builder->CreateAlloca(llvm::Type::getInt1Ty(context), nullptr); } else { throw CodeGenError("Only Integer, Real, Strings and Logical types are supported " "in logical binary operation.", x.base.base.loc); } switch (x.m_op) { case ASR::logicalbinopType::And: { - llvm_utils->create_if_else(cond, [&, result, left_val]() { - LLVM::CreateStore(*builder, left_val, result); - }, [&, result, right_val]() { - LLVM::CreateStore(*builder, right_val, result); - }); - tmp = LLVM::CreateLoad(*builder, result); + tmp = builder->CreateSelect(cond, left_val, right_val); break; }; case ASR::logicalbinopType::Or: { - llvm_utils->create_if_else(cond, [&, result, right_val]() { - LLVM::CreateStore(*builder, right_val, result); - - }, [&, result, left_val]() { - LLVM::CreateStore(*builder, left_val, result); - }); - tmp = LLVM::CreateLoad(*builder, result); + tmp = builder->CreateSelect(cond, right_val, left_val); break; }; case ASR::logicalbinopType::Xor: { @@ -5644,6 +5810,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor tmp = CreateGEP(str, idx_vec); } else { tmp = lfortran_str_item(str, idx); + strings_to_be_deallocated.push_back(al, tmp); } } @@ -5758,7 +5925,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor case ASR::binopType::Pow: { llvm::Type *type; int a_kind; - a_kind = down_cast(ASRUtils::type_get_past_pointer(x.m_type))->m_kind; + a_kind = down_cast(ASRUtils::extract_type(x.m_type))->m_kind; type = llvm_utils->getFPType(a_kind); llvm::Value *fleft = builder->CreateSIToFP(left_val, type); @@ -5841,7 +6008,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor case ASR::binopType::Pow: { llvm::Type *type; int a_kind; - a_kind = down_cast(ASRUtils::type_get_past_pointer(x.m_type))->m_kind; + a_kind = down_cast(ASRUtils::extract_type(x.m_type))->m_kind; type = llvm_utils->getFPType(a_kind); std::string func_name = a_kind == 4 ? "llvm.pow.f32" : "llvm.pow.f64"; llvm::Function *fn_pow = module->getFunction(func_name); @@ -6781,7 +6948,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } llvm::Function* get_read_function(ASR::ttype_t *type) { - type = ASRUtils::type_get_past_allocatable(type); + type = ASRUtils::type_get_past_allocatable( + ASRUtils::type_get_past_pointer(type)); llvm::Function *fn = nullptr; switch (type->type) { case (ASR::ttypeType::Integer): { @@ -6964,11 +7132,13 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor ASR::ttype_t* type = ASRUtils::expr_type(x.m_values[i]); llvm::Function *fn = get_read_function(type); if (ASRUtils::is_array(type)) { - if (ASR::is_a(*type)) { + if (ASR::is_a(*type) + || ASR::is_a(*type)) { tmp = CreateLoad(tmp); } tmp = arr_descr->get_pointer_to_data(tmp); - if (ASR::is_a(*type)) { + if (ASR::is_a(*type) + || ASR::is_a(*type)) { tmp = CreateLoad(tmp); } llvm::Value *arr = tmp; @@ -6981,6 +7151,28 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor builder->CreateCall(fn, {tmp, unit_val}); } } + + // In Fortran, read(u, *) is used to read the entire line. The + // next read(u, *) function is intended to read the next entire + // line. Let's take an example: `read(u, *) n`, where n is an + // integer. The first occurance of the integer value will be + // read, and anything after that will be skipped. + // Here, we can use `_lfortran_empty_read` function to move to the + // pointer to the next line. + std::string runtime_func_name = "_lfortran_empty_read"; + llvm::Function *fn = module->getFunction(runtime_func_name); + if (!fn) { + llvm::FunctionType *function_type = llvm::FunctionType::get( + llvm::Type::getVoidTy(context), { + llvm::Type::getInt32Ty(context), + llvm::Type::getInt32Ty(context)->getPointerTo() + }, false); + fn = llvm::Function::Create(function_type, + llvm::Function::ExternalLinkage, runtime_func_name, + *module); + } + this->visit_expr_wrapper(x.m_unit, true); + builder->CreateCall(fn, {unit_val, iostat}); } } @@ -7107,6 +7299,21 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor builder->CreateCall(fn, {tmp}); } + void visit_FileBackspace(const ASR::FileBackspace_t &x) { + std::string runtime_func_name = "_lfortran_backspace"; + llvm::Function *fn = module->getFunction(runtime_func_name); + if (!fn) { + llvm::FunctionType *function_type = llvm::FunctionType::get( + llvm::Type::getVoidTy(context), { + llvm::Type::getInt32Ty(context) + }, false); + fn = llvm::Function::Create(function_type, + llvm::Function::ExternalLinkage, runtime_func_name, *module); + } + this->visit_expr_wrapper(x.m_unit, true); + builder->CreateCall(fn, {tmp}); + } + void visit_FileClose(const ASR::FileClose_t &x) { llvm::Value *unit_val = nullptr; this->visit_expr_wrapper(x.m_unit, true); @@ -7129,15 +7336,88 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } void visit_FileWrite(const ASR::FileWrite_t &x) { - if (x.m_fmt != nullptr) { - diag.codegen_warning_label("format string in write() is not implemented yet and it is currently treated as '*'", - {x.m_fmt->base.loc}, "treated as '*'"); + if (x.m_unit == nullptr) { + handle_print(x); + return; } - if (x.m_unit != nullptr) { - diag.codegen_warning_label("unit in write() is not implemented yet and it is currently treated as '*'", - {x.m_unit->base.loc}, "treated as '*'"); + std::vector args; + std::vector args_type; + std::vector fmt; + llvm::Value *sep = nullptr; + llvm::Value *end = nullptr; + llvm::Value *unit = nullptr; + std::string runtime_func_name; + bool is_string = ASRUtils::is_character(*expr_type(x.m_unit)); + + int ptr_loads_copy = ptr_loads; + if ( is_string ) { + ptr_loads = 0; + runtime_func_name = "_lfortran_string_write"; + args_type.push_back(character_type->getPointerTo()); + } else if ( ASRUtils::is_integer(*expr_type(x.m_unit)) ) { + ptr_loads = 1; + runtime_func_name = "_lfortran_file_write"; + args_type.push_back(llvm::Type::getInt32Ty(context)); + } else { + throw CodeGenError("Unsupported type for `unit` in write(..)"); } - handle_print(x); + this->visit_expr_wrapper(x.m_unit); + ptr_loads = ptr_loads_copy; + unit = tmp; + + if (x.m_separator) { + this->visit_expr_wrapper(x.m_separator, true); + sep = tmp; + } else { + sep = builder->CreateGlobalStringPtr(" "); + } + if (x.m_end) { + this->visit_expr_wrapper(x.m_end, true); + end = tmp; + } else { + end = builder->CreateGlobalStringPtr("\n"); + } + size_t n_values = x.n_values; ASR::expr_t **m_values = x.m_values; + // TODO: Handle String Formatting + if (n_values > 0 && is_a(*m_values[0]) && is_string) { + n_values = down_cast(m_values[0])->n_args; + m_values = down_cast(m_values[0])->m_args; + } + for (size_t i=0; iCreateGlobalStringPtr(fmt_str); + + std::vector printf_args; + printf_args.push_back(unit); + printf_args.push_back(fmt_ptr); + printf_args.insert(printf_args.end(), args.begin(), args.end()); + llvm::Function *fn = module->getFunction(runtime_func_name); + if (!fn) { + args_type.push_back(llvm::Type::getInt8PtrTy(context)); + llvm::FunctionType *function_type = llvm::FunctionType::get( + llvm::Type::getVoidTy(context), args_type, true); + fn = llvm::Function::Create(function_type, + llvm::Function::ExternalLinkage, runtime_func_name, *module); + } + tmp = builder->CreateCall(fn, printf_args); } // It appends the format specifier and arg based on the type of expression @@ -7345,25 +7625,28 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } void visit_Stop(const ASR::Stop_t &x) { - if (compiler_options.emit_debug_info) { - debug_emit_loc(x); - llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr(infile); - llvm::Value *fmt_ptr1 = llvm::ConstantInt::get(context, llvm::APInt( - 1, compiler_options.use_colors)); - call_print_stacktrace_addresses(context, *module, *builder, - {fmt_ptr, fmt_ptr1}); - } - llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("STOP\n"); - print_error(context, *module, *builder, {fmt_ptr}); + if (compiler_options.emit_debug_info) debug_emit_loc(x); llvm::Value *exit_code; - if (x.m_code && ASRUtils::expr_type(x.m_code)->type == ASR::ttypeType::Integer) { + if (x.m_code && is_a(*ASRUtils::expr_type(x.m_code))) { this->visit_expr(*x.m_code); exit_code = tmp; + if (compiler_options.emit_debug_info) { + llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr(infile); + llvm::Value *fmt_ptr1 = llvm::ConstantInt::get(context, llvm::APInt( + 1, compiler_options.use_colors)); + llvm::Value *test = builder->CreateICmpNE(exit_code, builder->getInt32(0)); + llvm_utils->create_if_else(test, [=]() { + call_print_stacktrace_addresses(context, *module, *builder, + {fmt_ptr, fmt_ptr1}); + }, [](){}); + } } else { int exit_code_int = 0; exit_code = llvm::ConstantInt::get(context, llvm::APInt(32, exit_code_int)); } + llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("STOP\n"); + print_error(context, *module, *builder, {fmt_ptr}); exit(context, *module, *builder, exit_code); } @@ -7406,17 +7689,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor template std::vector convert_call_args(const T &x, bool is_method) { std::vector args; - const ASR::symbol_t* func_subrout = symbol_get_past_external(x.m_name); - ASR::abiType x_abi = ASR::abiType::Source; - if( is_a(*func_subrout) ) { - ASR::Function_t* func = down_cast(func_subrout); - x_abi = ASRUtils::get_FunctionType(func)->m_abi; - } - for (size_t i=0; i if (ASRUtils::get_FunctionType(fn)->m_deftype == ASR::deftypeType::Implementation) { LCOMPILERS_ASSERT(llvm_symtab_fn.find(h) != llvm_symtab_fn.end()); tmp = llvm_symtab_fn[h]; + } else if (llvm_symtab_fn_arg.find(h) == llvm_symtab_fn_arg.end() && + ASR::is_a(*var_sym) && + ASRUtils::get_FunctionType(fn)->m_deftype == ASR::deftypeType::Interface ) { + LCOMPILERS_ASSERT(llvm_symtab_fn.find(h) != llvm_symtab_fn.end()); + tmp = llvm_symtab_fn[h]; + LCOMPILERS_ASSERT(tmp != nullptr) } else { // Must be an argument/chained procedure pass LCOMPILERS_ASSERT(llvm_symtab_fn_arg.find(h) != llvm_symtab_fn_arg.end()); @@ -7725,14 +8006,16 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor // using alloca inside a loop, which would // run out of stack if( (ASR::is_a(*x.m_args[i].m_value) || - ASR::is_a(*x.m_args[i].m_value)) + (ASR::is_a(*x.m_args[i].m_value) && + (ASRUtils::is_array(arg_type) || + ASR::is_a(*ASRUtils::expr_type(x.m_args[i].m_value))))) && value->getType()->isPointerTy()) { value = CreateLoad(value); } if( !ASR::is_a(*arg_type) && !(orig_arg && !LLVM::is_llvm_pointer(*orig_arg->m_type) && LLVM::is_llvm_pointer(*arg_type) && - !ASRUtils::is_character(*orig_arg->m_type)) ) { + !ASRUtils::is_character(*orig_arg->m_type)) && !ASR::is_a(*x.m_args[i].m_value) ) { llvm::BasicBlock &entry_block = builder->GetInsertBlock()->getParent()->getEntryBlock(); llvm::IRBuilder<> builder0(context); builder0.SetInsertPoint(&entry_block, entry_block.getFirstInsertionPt()); @@ -8101,6 +8384,22 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor std::string m_name = ASRUtils::symbol_name(x.m_name); std::vector args2 = convert_call_args(x, is_method); args.insert(args.end(), args2.begin(), args2.end()); + // check if type of each arg is same as type of each arg in subrout_called + if (ASR::is_a(*symbol_get_past_external(x.m_name))) { + ASR::Function_t* subrout_called = ASR::down_cast(symbol_get_past_external(x.m_name)); + for (size_t i = 0; i < subrout_called->n_args; i++) { + ASR::expr_t* expected_arg = subrout_called->m_args[i]; + ASR::expr_t* passed_arg = x.m_args[i].m_value; + ASR::ttype_t* expected_arg_type = ASRUtils::expr_type(expected_arg); + ASR::ttype_t* passed_arg_type = ASRUtils::expr_type(passed_arg); + if (ASR::is_a(*passed_arg)) { + if (!ASRUtils::types_equal(expected_arg_type, passed_arg_type, true)) { + throw CodeGenError("Type mismatch in subroutine call, expected `" + ASRUtils::type_to_str_python(expected_arg_type) + + "`, passed `" + ASRUtils::type_to_str_python(passed_arg_type) + "`", x.m_args[i].m_value->base.loc); + } + } + } + } if (pass_arg) { args.push_back(pass_arg); } @@ -8363,6 +8662,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } void visit_FunctionCall(const ASR::FunctionCall_t &x) { + if ( compiler_options.emit_debug_info ) debug_emit_loc(x); if( ASRUtils::is_intrinsic_optimization(x.m_name) ) { ASR::Function_t* routine = ASR::down_cast( ASRUtils::symbol_get_past_external(x.m_name)); @@ -8485,13 +8785,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor int dt_idx = name2memidx[ASRUtils::symbol_name(struct_sym)] [ASRUtils::symbol_name(ASRUtils::symbol_get_past_external(struct_mem->m_m))]; - llvm::Value* dt_1 = llvm_utils->create_gep( - dt, dt_idx); - dt_1 = llvm_utils->create_gep(dt_1, 1); + llvm::Value* dt_1 = llvm_utils->create_gep(dt, dt_idx); + dt_1 = CreateLoad(llvm_utils->create_gep(CreateLoad(dt_1), 1)); llvm::Value* class_ptr = llvm_utils->create_gep(dt_polymorphic, 1); - if (is_nested_pointer(dt_1)) { - dt_1 = CreateLoad(dt_1); - } builder->CreateStore(dt_1, class_ptr); if (self_argument.length() == 0) { args.push_back(dt_polymorphic); @@ -8655,6 +8951,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } } } + if (ASRUtils::is_character(*x.m_type)) { + strings_to_be_deallocated.push_back(al, tmp); + } } void visit_ArraySizeUtil(ASR::expr_t* m_v, ASR::ttype_t* m_type, @@ -8904,25 +9203,11 @@ Result> asr_to_llvm(ASR::TranslationUnit_t &asr, skip_optimization_func_instantiation.push_back(static_cast( ASRUtils::IntrinsicScalarFunctions::SignFromValue)); - pass_options.runtime_library_dir = co.runtime_library_dir; - pass_options.mod_files_dir = co.mod_files_dir; - pass_options.include_dirs = co.include_dirs; - pass_options.run_fun = run_fn; - pass_options.always_run = false; - pass_options.verbose = co.verbose; - pass_options.dump_all_passes = co.dump_all_passes; - pass_options.use_loop_variable_after_loop = co.use_loop_variable_after_loop; - pass_options.realloc_lhs = co.realloc_lhs; - pass_options.skip_optimization_func_instantiation = skip_optimization_func_instantiation; + co.po.run_fun = run_fn; + co.po.always_run = false; + co.po.skip_optimization_func_instantiation = skip_optimization_func_instantiation; pass_manager.rtlib = co.rtlib; - - pass_options.all_symbols_mangling = co.all_symbols_mangling; - pass_options.module_name_mangling = co.module_name_mangling; - pass_options.global_symbols_mangling = co.global_symbols_mangling; - pass_options.intrinsic_symbols_mangling = co.intrinsic_symbols_mangling; - pass_options.bindc_mangling = co.bindc_mangling; - pass_options.mangle_underscore = co.mangle_underscore; - pass_manager.apply_passes(al, &asr, pass_options, diagnostics); + pass_manager.apply_passes(al, &asr, co.po, diagnostics); // Uncomment for debugging the ASR after the transformation // std::cout << LCompilers::pickle(asr, true, false, false) << std::endl; diff --git a/src/libasr/codegen/asr_to_wasm.cpp b/src/libasr/codegen/asr_to_wasm.cpp index 86c2856334a..d83a3511714 100644 --- a/src/libasr/codegen/asr_to_wasm.cpp +++ b/src/libasr/codegen/asr_to_wasm.cpp @@ -3044,12 +3044,6 @@ class ASRToWASMVisitor : public ASR::BaseVisitor { } void visit_Print(const ASR::Print_t &x) { - if (x.m_fmt != nullptr) { - diag.codegen_warning_label( - "format string in `print` is not implemented yet and it is " - "currently treated as '*'", - {x.m_fmt->base.loc}, "treated as '*'"); - } handle_print(x); } @@ -3061,12 +3055,6 @@ class ASRToWASMVisitor : public ASR::BaseVisitor { } void visit_FileWrite(const ASR::FileWrite_t &x) { - if (x.m_fmt != nullptr) { - diag.codegen_warning_label( - "format string in `print` is not implemented yet and it is " - "currently treated as '*'", - {x.m_fmt->base.loc}, "treated as '*'"); - } if (x.m_unit != nullptr) { diag.codegen_error_label("unit in write() is not implemented yet", {x.m_unit->base.loc}, "not implemented"); @@ -3216,15 +3204,12 @@ Result> asr_to_wasm_bytes_stream(ASR::TranslationUnit_t &asr, CompilerOptions &co) { ASRToWASMVisitor v(al, diagnostics); - LCompilers::PassOptions pass_options; - pass_options.always_run = true; - pass_options.verbose = co.verbose; - pass_options.dump_all_passes = co.dump_all_passes; + co.po.always_run = true; std::vector passes = {"pass_array_by_data", "array_op", "implied_do_loops", "print_arr", "do_loops", "select_case", "nested_vars", "unused_functions", "intrinsic_function"}; LCompilers::PassManager pass_manager; - pass_manager.apply_passes(al, &asr, passes, pass_options, diagnostics); + pass_manager.apply_passes(al, &asr, passes, co.po, diagnostics); #ifdef SHOW_ASR diff --git a/src/libasr/codegen/evaluator.cpp b/src/libasr/codegen/evaluator.cpp index 1068f1216a9..f2f1a7fcc00 100644 --- a/src/libasr/codegen/evaluator.cpp +++ b/src/libasr/codegen/evaluator.cpp @@ -178,7 +178,7 @@ LLVMEvaluator::LLVMEvaluator(const std::string &t) std::string CPU = "generic"; std::string features = ""; llvm::TargetOptions opt; - llvm::Optional RM = llvm::Reloc::Model::PIC_; + RM_OPTIONAL_TYPE RM = llvm::Reloc::Model::PIC_; TM = target->createTargetMachine(target_triple, CPU, features, opt, RM); // For some reason the JIT requires a different TargetMachine diff --git a/src/libasr/codegen/llvm_array_utils.cpp b/src/libasr/codegen/llvm_array_utils.cpp index a0edb7f9299..3f884952e72 100644 --- a/src/libasr/codegen/llvm_array_utils.cpp +++ b/src/libasr/codegen/llvm_array_utils.cpp @@ -78,21 +78,20 @@ namespace LCompilers { std::unique_ptr Descriptor::get_descriptor - (llvm::LLVMContext& context, - llvm::IRBuilder<>* builder, - LLVMUtils* llvm_utils, - DESCR_TYPE descr_type) { + (llvm::LLVMContext& context, llvm::IRBuilder<>* builder, + LLVMUtils* llvm_utils, DESCR_TYPE descr_type, + CompilerOptions& co, std::vector& heap_arrays_) { switch( descr_type ) { case DESCR_TYPE::_SimpleCMODescriptor: { - return std::make_unique(context, builder, llvm_utils); + return std::make_unique(context, builder, llvm_utils, co, heap_arrays_); } } return nullptr; } SimpleCMODescriptor::SimpleCMODescriptor(llvm::LLVMContext& _context, - llvm::IRBuilder<>* _builder, - LLVMUtils* _llvm_utils): + llvm::IRBuilder<>* _builder, LLVMUtils* _llvm_utils, CompilerOptions& co_, + std::vector& heap_arrays_): context(_context), llvm_utils(std::move(_llvm_utils)), builder(std::move(_builder)), @@ -103,7 +102,7 @@ namespace LCompilers { llvm::Type::getInt32Ty(context), llvm::Type::getInt32Ty(context)}), "dimension_descriptor") - ) { + ), co(co_), heap_arrays(heap_arrays_) { } bool SimpleCMODescriptor::is_array(ASR::ttype_t* asr_type) { @@ -258,7 +257,7 @@ namespace LCompilers { void SimpleCMODescriptor::fill_array_details( llvm::Value* arr, llvm::Type* llvm_data_type, int n_dims, std::vector>& llvm_dims, - bool reserve_data_memory) { + llvm::Module* module, bool reserve_data_memory) { llvm::Value* offset_val = llvm_utils->create_gep(arr, 1); builder->CreateStore(llvm::ConstantInt::get(context, llvm::APInt(32, 0)), offset_val); llvm::Value* dim_des_val = llvm_utils->create_gep(arr, 2); @@ -289,8 +288,23 @@ namespace LCompilers { llvm::Value* llvm_size = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); builder->CreateStore(prod, llvm_size); llvm::Value* first_ptr = get_pointer_to_data(arr); - llvm::Value* arr_first = builder->CreateAlloca(llvm_data_type, - LLVM::CreateLoad(*builder, llvm_size)); + llvm::Value* arr_first = nullptr; + + if( !co.stack_arrays ) { + llvm::DataLayout data_layout(module); + uint64_t size = data_layout.getTypeAllocSize(llvm_data_type); + builder->CreateStore(builder->CreateMul( + LLVM::CreateLoad(*builder, llvm_size), + llvm::ConstantInt::get(context, llvm::APInt(32, size))), llvm_size); + llvm::Value* arr_first_i8 = lfortran_malloc( + context, *module, *builder, LLVM::CreateLoad(*builder, llvm_size)); + heap_arrays.push_back(arr_first_i8); + arr_first = builder->CreateBitCast( + arr_first_i8, llvm_data_type->getPointerTo()); + } else { + arr_first = builder->CreateAlloca( + llvm_data_type, LLVM::CreateLoad(*builder, llvm_size)); + } builder->CreateStore(arr_first, first_ptr); } @@ -391,7 +405,7 @@ namespace LCompilers { builder->CreateStore(stride, s_val); llvm::Value* l_val = llvm_utils->create_gep(dim_val, 1); llvm::Value* dim_size_ptr = llvm_utils->create_gep(dim_val, 2); - builder->CreateStore(llvm::ConstantInt::get(context, llvm::APInt(32, 0)), l_val); + builder->CreateStore(llvm::ConstantInt::get(context, llvm::APInt(32, 1)), l_val); llvm::Value* dim_size = this->get_dimension_size( this->get_pointer_to_dimension_descriptor(source_dim_des_arr, llvm::ConstantInt::get(context, llvm::APInt(32, r)))); @@ -658,15 +672,18 @@ namespace LCompilers { tmp = builder->CreateSExt(tmp, llvm_utils->getIntType(kind)); return tmp; } + llvm::BasicBlock &entry_block = builder->GetInsertBlock()->getParent()->getEntryBlock(); + llvm::IRBuilder<> builder0(context); + builder0.SetInsertPoint(&entry_block, entry_block.getFirstInsertionPt()); llvm::Value* rank = this->get_rank(array); - llvm::Value* llvm_size = builder->CreateAlloca(llvm_utils->getIntType(kind), nullptr); + llvm::Value* llvm_size = builder0.CreateAlloca(llvm_utils->getIntType(kind), nullptr); builder->CreateStore(llvm::ConstantInt::get(context, llvm::APInt(kind * 8, 1)), llvm_size); llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body"); llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end"); - llvm::Value* r = builder->CreateAlloca(llvm_utils->getIntType(4), nullptr); + llvm::Value* r = builder0.CreateAlloca(llvm_utils->getIntType(4), nullptr); builder->CreateStore(llvm::ConstantInt::get(context, llvm::APInt(32, 0)), r); // head llvm_utils->start_new_block(loophead); diff --git a/src/libasr/codegen/llvm_array_utils.h b/src/libasr/codegen/llvm_array_utils.h index 002d6bdc901..155225b57e5 100644 --- a/src/libasr/codegen/llvm_array_utils.h +++ b/src/libasr/codegen/llvm_array_utils.h @@ -74,7 +74,9 @@ namespace LCompilers { llvm::LLVMContext& context, llvm::IRBuilder<>* builder, LLVMUtils* llvm_utils, - DESCR_TYPE descr_type); + DESCR_TYPE descr_type, + CompilerOptions& co_, + std::vector& heap_arrays_); /* * Checks whether the given ASR::ttype_t* is an @@ -134,7 +136,7 @@ namespace LCompilers { void fill_array_details( llvm::Value* arr, llvm::Type* llvm_data_type, int n_dims, std::vector>& llvm_dims, - bool reserve_data_memory=true) = 0; + llvm::Module* module, bool reserve_data_memory=true) = 0; virtual void fill_array_details( @@ -308,6 +310,9 @@ namespace LCompilers { std::map> tkr2array; + CompilerOptions& co; + std::vector& heap_arrays; + llvm::Value* cmo_convertor_single_element( llvm::Value* arr, std::vector& m_args, int n_args, bool check_for_bounds); @@ -320,7 +325,8 @@ namespace LCompilers { SimpleCMODescriptor(llvm::LLVMContext& _context, llvm::IRBuilder<>* _builder, - LLVMUtils* _llvm_utils); + LLVMUtils* _llvm_utils, CompilerOptions& co_, + std::vector& heap_arrays); virtual bool is_array(ASR::ttype_t* asr_type); @@ -351,7 +357,7 @@ namespace LCompilers { void fill_array_details( llvm::Value* arr, llvm::Type* llvm_data_type, int n_dims, std::vector>& llvm_dims, - bool reserve_data_memory=true); + llvm::Module* module, bool reserve_data_memory=true); virtual void fill_array_details( diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index 3c6d269f541..88f77ec9307 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -1111,18 +1111,17 @@ namespace LCompilers { ASR::abiType m_abi, bool is_pointer) { llvm::Type* llvm_type = nullptr; - #define handle_llvm_pointers1() if (n_dims == 0 && ASR::is_a(*t2)) { \ - llvm_type = character_type; \ - } else { \ - bool is_pointer_ = (ASR::is_a(*t2) || \ - (ASR::is_a(*t2) && m_abi != ASR::abiType::BindC) ); \ - llvm_type = get_type_from_ttype_t(t2, nullptr, m_storage, is_array_type, \ - is_malloc_array_type, is_list, m_dims, \ - n_dims, a_kind, module, m_abi, is_pointer_); \ - if( !is_pointer_ ) { \ - llvm_type = llvm_type->getPointerTo(); \ - } \ - } \ + #define handle_llvm_pointers1() \ + if (n_dims == 0 && ASR::is_a(*t2)) { \ + llvm_type = character_type; \ + } else { \ + llvm_type = get_type_from_ttype_t(t2, nullptr, m_storage, \ + is_array_type, is_malloc_array_type, is_list, m_dims, \ + n_dims, a_kind, module, m_abi, is_pointer_); \ + if( !is_pointer_ ) { \ + llvm_type = llvm_type->getPointerTo(); \ + } \ + } switch (asr_type->type) { case ASR::ttypeType::Array: { @@ -1206,11 +1205,16 @@ namespace LCompilers { } case (ASR::ttypeType::Pointer) : { ASR::ttype_t *t2 = ASR::down_cast(asr_type)->m_type; + bool is_pointer_ = ( ASR::is_a(*t2) || + (ASR::is_a(*t2) && m_abi != ASR::abiType::BindC) ); + is_malloc_array_type = ASRUtils::is_array(t2); handle_llvm_pointers1() break; } case (ASR::ttypeType::Allocatable) : { ASR::ttype_t *t2 = ASR::down_cast(asr_type)->m_type; + bool is_pointer_ = (ASR::is_a(*t2) + && m_abi != ASR::abiType::BindC); is_malloc_array_type = ASRUtils::is_array(t2); handle_llvm_pointers1() break; @@ -1377,7 +1381,11 @@ namespace LCompilers { // to our new block builder->CreateBr(bb); } +#if LLVM_VERSION_MAJOR >= 16 + fn->insert(fn->end(), bb); +#else fn->getBasicBlockList().push_back(bb); +#endif builder->SetInsertPoint(bb); } diff --git a/src/libasr/compiler_tester/tester.py b/src/libasr/compiler_tester/tester.py index c94ecb2539e..e6c492602cf 100644 --- a/src/libasr/compiler_tester/tester.py +++ b/src/libasr/compiler_tester/tester.py @@ -360,6 +360,8 @@ def tester_main(compiler, single_test): help="Skip LLVM tests") parser.add_argument("--skip-run-with-dbg", action="store_true", help="Skip runtime tests with debugging information enabled") + parser.add_argument("--skip-cpptranslate", action="store_true", + help="Skip tests for ast_openmp that depend on cpptranslate") parser.add_argument("-s", "--sequential", action="store_true", help="Run all tests sequentially") parser.add_argument("--no-color", action="store_true", @@ -380,6 +382,7 @@ def tester_main(compiler, single_test): verbose = args.verbose no_llvm = args.no_llvm skip_run_with_dbg = args.skip_run_with_dbg + skip_cpptranslate = args.skip_cpptranslate global no_color no_color = args.no_color @@ -411,6 +414,7 @@ def tester_main(compiler, single_test): verbose=verbose, no_llvm=no_llvm, skip_run_with_dbg=True, + skip_cpptranslate=True, no_color=True) filtered_tests = [test for test in filtered_tests if 'extrafiles' not in test] @@ -423,6 +427,7 @@ def tester_main(compiler, single_test): verbose=verbose, no_llvm=no_llvm, skip_run_with_dbg=skip_run_with_dbg, + skip_cpptranslate=skip_cpptranslate, no_color=no_color) # run in parallel else: @@ -434,6 +439,7 @@ def tester_main(compiler, single_test): verbose=verbose, no_llvm=no_llvm, skip_run_with_dbg=skip_run_with_dbg, + skip_cpptranslate=skip_cpptranslate, no_color=no_color) with ThreadPoolExecutor() as ex: futures = ex.map(single_tester_partial_args, filtered_tests) diff --git a/src/libasr/dwarf_convert.py b/src/libasr/dwarf_convert.py index cfd5a56e818..56171c8fbda 100755 --- a/src/libasr/dwarf_convert.py +++ b/src/libasr/dwarf_convert.py @@ -64,15 +64,21 @@ def parse_file(self, filename): def parse_debug_line(self): self.line = self.file.readline() + include_dirs_found = True while not self.line.startswith("include_directories"): - self.line = self.file.readline() + if self.line.startswith("file_names"): + include_dirs_found = False + break + else: + self.line = self.file.readline() include_directories = [] - while self.line.startswith("include_directories"): - n, path = re.compile(r"include_directories\[[ ]*(\d+)\] = \"([^\"]+)\"").findall(self.line)[0] - n = int(n) - include_directories.append(IncludeDirectory(n, path)) - self.line = self.file.readline() + if include_dirs_found: + while self.line.startswith("include_directories"): + n, path = re.compile(r"include_directories\[[ ]*(\d+)\] = \"([^\"]+)\"").findall(self.line)[0] + n = int(n) + include_directories.append(IncludeDirectory(n, path)) + self.line = self.file.readline() file_names = [] while self.line.startswith("file_names"): diff --git a/src/libasr/gen_pass.py b/src/libasr/gen_pass.py index c77e4c29fd2..e88114f8e56 100644 --- a/src/libasr/gen_pass.py +++ b/src/libasr/gen_pass.py @@ -32,6 +32,7 @@ "update_array_dim_intrinsic_calls", "replace_where", "unique_symbols", + "insert_deallocate" ] diff --git a/src/libasr/pass/array_op.cpp b/src/libasr/pass/array_op.cpp index 19bd0277876..d3336cee967 100644 --- a/src/libasr/pass/array_op.cpp +++ b/src/libasr/pass/array_op.cpp @@ -95,16 +95,19 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer { result_lbound(result_lbound_), result_ubound(result_ubound_), result_inc(result_inc_), op_dims(nullptr), op_n_dims(0), op_expr(nullptr), resultvar2value(resultvar2value_), - realloc_lhs(realloc_lhs_), current_scope(nullptr), result_var(nullptr), - result_type(nullptr) {} + realloc_lhs(realloc_lhs_), current_scope(nullptr), + result_var(nullptr), result_type(nullptr) {} template void create_do_loop(const Location& loc, int var_rank, int result_rank, Vec& idx_vars, Vec& loop_vars, - Vec& idx_vars_value, std::vector& loop_var_indices, - Vec& doloop_body, ASR::expr_t* op_expr, int op_expr_dim_offset, + Vec& idx_vars_value1, Vec& idx_vars_value2, std::vector& loop_var_indices, + Vec& doloop_body, ASR::expr_t* op_expr1, ASR::expr_t* op_expr2, int op_expr_dim_offset, LOOP_BODY loop_body) { - PassUtils::create_idx_vars(idx_vars_value, var_rank, loc, al, current_scope, "_v"); + PassUtils::create_idx_vars(idx_vars_value1, var_rank, loc, al, current_scope, "_v"); + if (op_expr2 != nullptr) { + PassUtils::create_idx_vars(idx_vars_value2, var_rank, loc, al, current_scope, "_u"); + } if( use_custom_loop_params ) { PassUtils::create_idx_vars(idx_vars, loop_vars, loop_var_indices, result_ubound, result_inc, @@ -139,26 +142,47 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer { loop_body(); } else { if( var_rank > 0 ) { - ASR::expr_t* idx_lb = PassUtils::get_bound(op_expr, i + op_expr_dim_offset, "lbound", al); + ASR::expr_t* idx_lb = PassUtils::get_bound(op_expr1, i + op_expr_dim_offset, "lbound", al); ASR::stmt_t* set_to_one = ASRUtils::STMT(ASR::make_Assignment_t( - al, loc, idx_vars_value[i+1], idx_lb, nullptr)); + al, loc, idx_vars_value1[i+1], idx_lb, nullptr)); doloop_body.push_back(al, set_to_one); + + if (op_expr2 != nullptr) { + ASR::expr_t* idx_lb2 = PassUtils::get_bound(op_expr2, i + op_expr_dim_offset, "lbound", al); + ASR::stmt_t* set_to_one2 = ASRUtils::STMT(ASR::make_Assignment_t( + al, loc, idx_vars_value2[i+1], idx_lb2, nullptr)); + doloop_body.push_back(al, set_to_one2); + } } doloop_body.push_back(al, doloop); } if( var_rank > 0 ) { ASR::expr_t* inc_expr = ASRUtils::EXPR(ASR::make_IntegerBinOp_t( - al, loc, idx_vars_value[i], ASR::binopType::Add, const_1, int32_type, nullptr)); + al, loc, idx_vars_value1[i], ASR::binopType::Add, const_1, int32_type, nullptr)); ASR::stmt_t* assign_stmt = ASRUtils::STMT(ASR::make_Assignment_t( - al, loc, idx_vars_value[i], inc_expr, nullptr)); + al, loc, idx_vars_value1[i], inc_expr, nullptr)); doloop_body.push_back(al, assign_stmt); + + if (op_expr2 != nullptr) { + ASR::expr_t* inc_expr2 = ASRUtils::EXPR(ASR::make_IntegerBinOp_t( + al, loc, idx_vars_value2[i], ASR::binopType::Add, const_1, int32_type, nullptr)); + ASR::stmt_t* assign_stmt2 = ASRUtils::STMT(ASR::make_Assignment_t( + al, loc, idx_vars_value2[i], inc_expr2, nullptr)); + doloop_body.push_back(al, assign_stmt2); + } } doloop = ASRUtils::STMT(ASR::make_DoLoop_t(al, loc, nullptr, head, doloop_body.p, doloop_body.size())); } if( var_rank > 0 ) { - ASR::expr_t* idx_lb = PassUtils::get_bound(op_expr, 1, "lbound", al); - ASR::stmt_t* set_to_one = ASRUtils::STMT(ASR::make_Assignment_t(al, loc, idx_vars_value[0], idx_lb, nullptr)); + ASR::expr_t* idx_lb = PassUtils::get_bound(op_expr1, 1, "lbound", al); + ASR::stmt_t* set_to_one = ASRUtils::STMT(ASR::make_Assignment_t(al, loc, idx_vars_value1[0], idx_lb, nullptr)); pass_result.push_back(al, set_to_one); + + if (op_expr2 != nullptr) { + ASR::expr_t* idx_lb2 = PassUtils::get_bound(op_expr2, 1, "lbound", al); + ASR::stmt_t* set_to_one2 = ASRUtils::STMT(ASR::make_Assignment_t(al, loc, idx_vars_value2[0], idx_lb2, nullptr)); + pass_result.push_back(al, set_to_one2); + } } pass_result.push_back(al, doloop); } else if (var_rank == 0) { @@ -239,8 +263,8 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer { std::vector loop_var_indices; Vec doloop_body; create_do_loop(loc, var_rank, result_rank, idx_vars, - loop_vars, idx_vars_value, loop_var_indices, doloop_body, - *current_expr, 2, + loop_vars, idx_vars_value, idx_vars_value, loop_var_indices, doloop_body, + *current_expr, nullptr, 2, [=, &idx_vars_value, &idx_vars, &doloop_body]() { ASR::expr_t* ref = nullptr; if( var_rank > 0 ) { @@ -258,7 +282,7 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer { use_custom_loop_params = false; } - #define allocate_result_var(op_arg, op_dims_arg, op_n_dims_arg) if( ASR::is_a(*ASRUtils::expr_type(result_var)) || \ + #define allocate_result_var(op_arg, op_dims_arg, op_n_dims_arg, result_var_created) if( ASR::is_a(*ASRUtils::expr_type(result_var)) || \ ASR::is_a(*ASRUtils::expr_type(result_var)) ) { \ bool is_dimension_empty = false; \ for( int i = 0; i < op_n_dims_arg; i++ ) { \ @@ -302,6 +326,13 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer { op_dims = alloc_dims.p; \ op_n_dims = alloc_dims.size(); \ } \ + Vec to_be_deallocated; \ + to_be_deallocated.reserve(al, alloc_args.size()); \ + for( size_t i = 0; i < alloc_args.size(); i++ ) { \ + to_be_deallocated.push_back(al, alloc_args.p[i].m_a); \ + } \ + pass_result.push_back(al, ASRUtils::STMT(ASR::make_ExplicitDeallocate_t( \ + al, loc, to_be_deallocated.p, to_be_deallocated.size()))); \ pass_result.push_back(al, ASRUtils::STMT(ASR::make_Allocate_t(al, \ loc, alloc_args.p, alloc_args.size(), nullptr, nullptr, nullptr))); \ } @@ -329,7 +360,7 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer { result_var_type, al, current_scope); result_counter += 1; if( allocate ) { - allocate_result_var(arr_expr, arr_expr_dims, arr_expr_n_dims); + allocate_result_var(arr_expr, arr_expr_dims, arr_expr_n_dims, true); } } @@ -339,8 +370,8 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer { int result_rank = PassUtils::get_rank(result_var); op_expr = arr_expr; create_do_loop(loc, n_dims, result_rank, idx_vars, - loop_vars, idx_vars_value, loop_var_indices, doloop_body, - op_expr, 2, [=, &arr_expr, &idx_vars, &idx_vars_value, &doloop_body]() { + loop_vars, idx_vars_value, idx_vars_value, loop_var_indices, doloop_body, + op_expr, nullptr, 2, [=, &arr_expr, &idx_vars, &idx_vars_value, &doloop_body]() { ASR::expr_t* ref = PassUtils::create_array_ref(arr_expr, idx_vars_value, al); LCOMPILERS_ASSERT(result_var != nullptr); ASR::expr_t* res = PassUtils::create_array_ref(result_var, idx_vars, al); @@ -369,6 +400,14 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer { replace_vars_helper(x); } + void replace_ArrayBroadcast(ASR::ArrayBroadcast_t* x) { + ASR::expr_t** current_expr_copy_161 = current_expr; + current_expr = &(x->m_array); + replace_expr(x->m_array); + current_expr = current_expr_copy_161; + *current_expr = x->m_array; + } + template void create_do_loop(const Location& loc, int result_rank, Vec& idx_vars, Vec& idx_vars_value, @@ -431,6 +470,46 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer { pass_result.push_back(al, doloop); } + template + void create_do_loop_for_const_val(const Location& loc, int result_rank, + Vec& idx_vars, + Vec& loop_vars, std::vector& loop_var_indices, + Vec& doloop_body, LOOP_BODY loop_body) { + if ( use_custom_loop_params ) { + PassUtils::create_idx_vars(idx_vars, loop_vars, loop_var_indices, + result_ubound, result_inc, loc, al, current_scope, "_t"); + } else { + PassUtils::create_idx_vars(idx_vars, result_rank, loc, al, current_scope, "_t"); + loop_vars.from_pointer_n_copy(al, idx_vars.p, idx_vars.size()); + } + + ASR::stmt_t* doloop = nullptr; + for ( int i = (int) loop_vars.size() - 1; i >= 0; i-- ) { + // TODO: Add an If debug node to check if the lower and upper bounds of both the arrays are same. + ASR::do_loop_head_t head; + head.m_v = loop_vars[i]; + if ( use_custom_loop_params ) { + int j = loop_var_indices[i]; + head.m_start = result_lbound[j]; + head.m_end = result_ubound[j]; + head.m_increment = result_inc[j]; + } else { + head.m_start = PassUtils::get_bound(result_var, i + 1, "lbound", al); + head.m_end = PassUtils::get_bound(result_var, i + 1, "ubound", al); + head.m_increment = nullptr; + } + head.loc = head.m_v->base.loc; + doloop_body.reserve(al, 1); + if ( doloop == nullptr ) { + loop_body(); + } else { + doloop_body.push_back(al, doloop); + } + doloop = ASRUtils::STMT(ASR::make_DoLoop_t(al, loc, nullptr, head, doloop_body.p, doloop_body.size())); + } + pass_result.push_back(al, doloop); + } + template void replace_Constant(T* x) { if( !(result_var != nullptr && PassUtils::is_array(result_var) && @@ -441,11 +520,11 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer { const Location& loc = x->base.base.loc; int n_dims = PassUtils::get_rank(result_var); - Vec idx_vars, loop_vars, idx_vars_value; + Vec idx_vars, loop_vars; std::vector loop_var_indices; Vec doloop_body; - create_do_loop(loc, n_dims, idx_vars, idx_vars_value, - loop_vars, loop_var_indices, doloop_body, result_var, + create_do_loop_for_const_val(loc, n_dims, idx_vars, + loop_vars, loop_var_indices, doloop_body, [=, &idx_vars, &doloop_body] () { ASR::expr_t* ref = *current_expr; ASR::expr_t* res = PassUtils::create_array_ref(result_var, idx_vars, al, current_scope); @@ -679,8 +758,13 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer { result_var = nullptr; this->replace_expr(x->m_left); ASR::expr_t* left = *current_expr; - left_dims = op_dims; - rank_left = op_n_dims; + if (!is_a(*x->m_left)) { + left_dims = op_dims; + rank_left = op_n_dims; + } else { + left_dims = nullptr; + rank_left = 0; + } current_expr = current_expr_copy_35; ASR::expr_t** current_expr_copy_36 = current_expr; @@ -691,8 +775,13 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer { result_var = nullptr; this->replace_expr(x->m_right); ASR::expr_t* right = *current_expr; - right_dims = op_dims; - rank_right = op_n_dims; + if (!is_a(*x->m_right)) { + right_dims = op_dims; + rank_right = op_n_dims; + } else { + right_dims = nullptr; + rank_right = 0; + } current_expr = current_expr_copy_36; op_dims = op_dims_copy; @@ -705,6 +794,17 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer { bool new_result_var_created = false; if( rank_left == 0 && rank_right == 0 ) { + if( result_var != nullptr ) { + ASR::stmt_t* auxiliary_assign_stmt_ = nullptr; + std::string name = current_scope->get_unique_name( + "__libasr_created_scalar_auxiliary_variable"); + *current_expr = PassUtils::create_auxiliary_variable_for_expr( + *current_expr, name, al, current_scope, auxiliary_assign_stmt_); + LCOMPILERS_ASSERT(auxiliary_assign_stmt_ != nullptr); + pass_result.push_back(al, auxiliary_assign_stmt_); + resultvar2value[result_var] = *current_expr; + replace_Var(ASR::down_cast(*current_expr)); + } return ; } @@ -728,14 +828,14 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer { result_var_type, al, current_scope); result_counter += 1; if( allocate ) { - allocate_result_var(left, left_dims, rank_left); + allocate_result_var(left, left_dims, rank_left, true); } new_result_var_created = true; } *current_expr = result_var; int result_rank = PassUtils::get_rank(result_var); - Vec idx_vars, idx_vars_value, loop_vars; + Vec idx_vars, idx_vars_value_left, idx_vars_value_right, loop_vars; std::vector loop_var_indices; Vec doloop_body; bool use_custom_loop_params_copy = use_custom_loop_params; @@ -743,10 +843,10 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer { use_custom_loop_params = false; } create_do_loop(loc, rank_left, result_rank, idx_vars, - loop_vars, idx_vars_value, loop_var_indices, doloop_body, left, 1, - [=, &left, &right, &idx_vars_value, &idx_vars, &doloop_body]() { - ASR::expr_t* ref_1 = PassUtils::create_array_ref(left, idx_vars_value, al, current_scope); - ASR::expr_t* ref_2 = PassUtils::create_array_ref(right, idx_vars_value, al, current_scope); + loop_vars, idx_vars_value_left, idx_vars_value_right, loop_var_indices, doloop_body, left, right, 1, + [=, &left, &right, &idx_vars_value_left, &idx_vars_value_right, &idx_vars, &doloop_body]() { + ASR::expr_t* ref_1 = PassUtils::create_array_ref(left, idx_vars_value_left, al, current_scope); + ASR::expr_t* ref_2 = PassUtils::create_array_ref(right, idx_vars_value_right, al, current_scope); ASR::expr_t* res = PassUtils::create_array_ref(result_var, idx_vars, al, current_scope); ASR::expr_t* op_el_wise = generate_element_wise_operation(loc, ref_1, ref_2, x); ASR::stmt_t* assign = ASRUtils::STMT(ASR::make_Assignment_t(al, loc, res, op_el_wise, nullptr)); @@ -794,7 +894,7 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer { result_var_type, al, current_scope); result_counter += 1; if( allocate ) { - allocate_result_var(arr_expr, arr_expr_dims, arr_expr_n_dims); + allocate_result_var(arr_expr, arr_expr_dims, arr_expr_n_dims, true); } new_result_var_created = true; } @@ -816,8 +916,8 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer { use_custom_loop_params = false; } create_do_loop(loc, n_dims, result_rank, idx_vars, - loop_vars, idx_vars_value, loop_var_indices, doloop_body, - op_expr, 2, [=, &arr_expr, &idx_vars, &idx_vars_value, &doloop_body]() { + loop_vars, idx_vars_value, idx_vars_value, loop_var_indices, doloop_body, + op_expr, nullptr, 2, [=, &arr_expr, &idx_vars, &idx_vars_value, &doloop_body]() { ASR::expr_t* ref = PassUtils::create_array_ref(arr_expr, idx_vars_value, al, current_scope); ASR::expr_t* res = PassUtils::create_array_ref(result_var, idx_vars, al, current_scope); ASR::expr_t *lexpr = nullptr, *rexpr = nullptr; @@ -920,11 +1020,11 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer { if (result_var) { int n_dims = PassUtils::get_rank(result_var); if (n_dims != 0) { - Vec idx_vars, loop_vars, idx_vars_value; + Vec idx_vars, loop_vars; std::vector loop_var_indices; Vec doloop_body; - create_do_loop(loc, n_dims, idx_vars, idx_vars_value, - loop_vars, loop_var_indices, doloop_body, ASRUtils::EXPR((ASR::asr_t*)x), + create_do_loop_for_const_val(loc, n_dims, idx_vars, + loop_vars, loop_var_indices, doloop_body, [=, &idx_vars, &doloop_body] () { ASR::expr_t* ref = ASRUtils::EXPR((ASR::asr_t*)x); ASR::expr_t* res = PassUtils::create_array_ref(result_var, idx_vars, al, current_scope); @@ -943,9 +1043,22 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer { bool result_var_created = false; if( rank_operand > 0 ) { if( result_var == nullptr ) { + bool allocate = false; + ASR::dimension_t *operand_dims = nullptr; + rank_operand = ASRUtils::extract_dimensions_from_ttype( + ASRUtils::expr_type(operand), operand_dims); + ASR::ttype_t* result_var_type = get_result_type(x->m_type, + operand_dims, rank_operand, loc, x->class_type, allocate); + if( allocate ) { + result_var_type = ASRUtils::TYPE(ASR::make_Allocatable_t(al, loc, + ASRUtils::type_get_past_allocatable(result_var_type))); + } result_var = PassUtils::create_var(result_counter, res_prefix, - loc, operand, al, current_scope); + loc, result_var_type, al, current_scope); result_counter += 1; + if( allocate ) { + allocate_result_var(operand, operand_dims, rank_operand, true); + } result_var_created = true; } *current_expr = result_var; @@ -1068,6 +1181,18 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer { at_least_one_array = at_least_one_array || array_mask[iarg]; } if (!at_least_one_array) { + if (result_var) { + // Scalar arguments + ASR::stmt_t* auxiliary_assign_stmt_ = nullptr; + std::string name = current_scope->get_unique_name( + "__libasr_created_scalar_auxiliary_variable"); + *current_expr = PassUtils::create_auxiliary_variable_for_expr( + *current_expr, name, al, current_scope, auxiliary_assign_stmt_); + LCOMPILERS_ASSERT(auxiliary_assign_stmt_ != nullptr); + pass_result.push_back(al, auxiliary_assign_stmt_); + resultvar2value[result_var] = *current_expr; + replace_Var(ASR::down_cast(*current_expr)); + } return ; } std::string res_prefix = "_elemental_func_call_res"; @@ -1110,13 +1235,13 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer { bool result_var_created = false; if( result_var == nullptr ) { result_var = PassUtils::create_var(result_counter, res_prefix, - loc, x->m_type, al, current_scope); + loc, *current_expr, al, current_scope); result_counter += 1; operand = first_array_operand; ASR::dimension_t* m_dims; int n_dims = ASRUtils::extract_dimensions_from_ttype( ASRUtils::expr_type(first_array_operand), m_dims); - allocate_result_var(operand, m_dims, n_dims); + allocate_result_var(operand, m_dims, n_dims, true); result_var_created = true; } *current_expr = result_var; @@ -1262,7 +1387,7 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer { ASR::dimension_t* m_dims; int n_dims = ASRUtils::extract_dimensions_from_ttype( ASRUtils::expr_type(operand), m_dims); - allocate_result_var(operand, m_dims, n_dims); + allocate_result_var(operand, m_dims, n_dims, result_var_created); *current_expr = result_var; Vec idx_vars, loop_vars, idx_vars_value; @@ -1437,7 +1562,12 @@ class ArrayOpVisitor : public ASR::CallReplacerOnExpressionsVisitor(*x.m_value) ) { + resultvar2value[replacer.result_var] = + ASR::down_cast(original_value)->m_array; + } else { + resultvar2value[replacer.result_var] = original_value; + } current_expr = const_cast(&(x.m_value)); this->call_replacer(); current_expr = current_expr_copy_9; @@ -1456,6 +1586,11 @@ class ArrayOpVisitor : public ASR::CallReplacerOnExpressionsVisitor(*ASRUtils::expr_type(x.m_target)) && + ASR::down_cast(ASRUtils::expr_type(x.m_target))->m_physical_type + == ASR::array_physical_typeType::SIMDArray) { + return; + } if( (ASR::is_a(*ASRUtils::expr_type(x.m_target)) && ASR::is_a(*x.m_value)) || (ASR::is_a(*x.m_value)) ) { @@ -1486,6 +1621,13 @@ class ArrayOpVisitor : public ASR::CallReplacerOnExpressionsVisitor to_be_deallocated; + to_be_deallocated.reserve(al, vec_alloc.size()); + for( size_t i = 0; i < vec_alloc.size(); i++ ) { + to_be_deallocated.push_back(al, vec_alloc.p[i].m_a); + } + pass_result.push_back(al, ASRUtils::STMT(ASR::make_ExplicitDeallocate_t( + al, x.base.base.loc, to_be_deallocated.p, to_be_deallocated.size()))); pass_result.push_back(al, ASRUtils::STMT(ASR::make_Allocate_t( al, x.base.base.loc, vec_alloc.p, 1, nullptr, nullptr, nullptr))); remove_original_statement = false; @@ -1538,6 +1680,163 @@ class ArrayOpVisitor : public ASR::CallReplacerOnExpressionsVisitor + void create_do_loop(const Location& loc, ASR::expr_t* value_array, int var_rank, int result_rank, + Vec& idx_vars, Vec& loop_vars, + Vec& idx_vars_value, std::vector& loop_var_indices, + Vec& doloop_body, ASR::expr_t* op_expr, int op_expr_dim_offset, + LOOP_BODY loop_body) { + PassUtils::create_idx_vars(idx_vars_value, var_rank, loc, al, current_scope, "_v"); + if( use_custom_loop_params ) { + PassUtils::create_idx_vars(idx_vars, loop_vars, loop_var_indices, + result_ubound, result_inc, + loc, al, current_scope, "_t"); + } else { + PassUtils::create_idx_vars(idx_vars, result_rank, loc, al, current_scope, "_t"); + loop_vars.from_pointer_n_copy(al, idx_vars.p, idx_vars.size()); + } + ASR::stmt_t* doloop = nullptr; + LCOMPILERS_ASSERT(result_rank >= var_rank); + // LCOMPILERS_ASSERT(var_rank == (int) loop_vars.size()); + ASR::ttype_t* int32_type = ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)); + ASR::expr_t* const_1 = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 1, int32_type)); + if (var_rank == (int) loop_vars.size()) { + for( int i = var_rank - 1; i >= 0; i-- ) { + // TODO: Add an If debug node to check if the lower and upper bounds of both the arrays are same. + ASR::do_loop_head_t head; + head.m_v = loop_vars[i]; + if( use_custom_loop_params ) { + int j = loop_var_indices[i]; + head.m_start = result_lbound[j]; + head.m_end = result_ubound[j]; + head.m_increment = result_inc[j]; + } else { + head.m_start = PassUtils::get_bound(value_array, i + 1, "lbound", al); + head.m_end = PassUtils::get_bound(value_array, i + 1, "ubound", al); + head.m_increment = nullptr; + } + head.loc = head.m_v->base.loc; + doloop_body.reserve(al, 1); + if( doloop == nullptr ) { + loop_body(); + } else { + if( var_rank > 0 ) { + ASR::expr_t* idx_lb = PassUtils::get_bound(op_expr, i + op_expr_dim_offset, "lbound", al); + ASR::stmt_t* set_to_one = ASRUtils::STMT(ASR::make_Assignment_t( + al, loc, idx_vars_value[i+1], idx_lb, nullptr)); + doloop_body.push_back(al, set_to_one); + } + doloop_body.push_back(al, doloop); + } + if( var_rank > 0 ) { + ASR::expr_t* inc_expr = ASRUtils::EXPR(ASR::make_IntegerBinOp_t( + al, loc, idx_vars_value[i], ASR::binopType::Add, const_1, int32_type, nullptr)); + ASR::stmt_t* assign_stmt = ASRUtils::STMT(ASR::make_Assignment_t( + al, loc, idx_vars_value[i], inc_expr, nullptr)); + doloop_body.push_back(al, assign_stmt); + } + doloop = ASRUtils::STMT(ASR::make_DoLoop_t(al, loc, nullptr, head, doloop_body.p, doloop_body.size())); + } + if( var_rank > 0 ) { + ASR::expr_t* idx_lb = PassUtils::get_bound(op_expr, 1, "lbound", al); + ASR::stmt_t* set_to_one = ASRUtils::STMT(ASR::make_Assignment_t(al, loc, idx_vars_value[0], idx_lb, nullptr)); + pass_result.push_back(al, set_to_one); + } + pass_result.push_back(al, doloop); + } else if (var_rank == 0) { + for( int i = loop_vars.size() - 1; i >= 0; i-- ) { + // TODO: Add an If debug node to check if the lower and upper bounds of both the arrays are same. + ASR::do_loop_head_t head; + head.m_v = loop_vars[i]; + if( use_custom_loop_params ) { + int j = loop_var_indices[i]; + head.m_start = result_lbound[j]; + head.m_end = result_ubound[j]; + head.m_increment = result_inc[j]; + } else { + head.m_start = PassUtils::get_bound(value_array, i + 1, "lbound", al); + head.m_end = PassUtils::get_bound(value_array, i + 1, "ubound", al); + head.m_increment = nullptr; + } + head.loc = head.m_v->base.loc; + doloop_body.reserve(al, 1); + if( doloop == nullptr ) { + loop_body(); + } else { + doloop_body.push_back(al, doloop); + } + doloop = ASRUtils::STMT(ASR::make_DoLoop_t(al, loc, nullptr, head, doloop_body.p, doloop_body.size())); + } + pass_result.push_back(al, doloop); + } + + } + + void visit_SubroutineCall(const ASR::SubroutineCall_t& x) { + ASR::symbol_t* sym = x.m_original_name; + if (sym && ASR::is_a(*sym)) { + ASR::ExternalSymbol_t* ext_sym = ASR::down_cast(sym); + std::string name = ext_sym->m_name; + std::string module_name = ext_sym->m_module_name; + if (module_name == "lfortran_intrinsic_math" && name == "random_number") { + // iterate over args and check if any of them is an array + ASR::expr_t* arg = nullptr; + for (size_t i=0; i idx_vars, loop_vars, idx_vars_value; + std::vector loop_var_indices; + Vec doloop_body; + create_do_loop(arg->base.loc, arg, var_rank, result_rank, idx_vars, + loop_vars, idx_vars_value, loop_var_indices, doloop_body, + arg, 2, + [=, &idx_vars_value, &idx_vars, &doloop_body]() { + Vec array_index; array_index.reserve(al, idx_vars.size()); + for( size_t i = 0; i < idx_vars.size(); i++ ) { + ASR::array_index_t idx; + idx.m_left = nullptr; + idx.m_right = idx_vars_value[i]; + idx.m_step = nullptr; + idx.loc = idx_vars_value[i]->base.loc; + array_index.push_back(al, idx); + } + ASR::expr_t* array_item = ASRUtils::EXPR(ASR::make_ArrayItem_t(al, x.base.base.loc, + arg, array_index.p, array_index.size(), + ASRUtils::type_get_past_array(ASRUtils::type_get_past_pointer( + ASRUtils::type_get_past_allocatable(ASRUtils::expr_type(arg)))), + ASR::arraystorageType::ColMajor, nullptr)); + Vec ref_args; ref_args.reserve(al, 1); + ASR::call_arg_t ref_arg; ref_arg.loc = array_item->base.loc; ref_arg.m_value = array_item; + ref_args.push_back(al, ref_arg); + ASR::stmt_t* subroutine_call = ASRUtils::STMT(ASRUtils::make_SubroutineCall_t_util(al, x.base.base.loc, + x.m_name, x.m_original_name, ref_args.p, ref_args.n, nullptr, nullptr, false)); + doloop_body.push_back(al, subroutine_call); + }); + remove_original_statement = true; + } + } + } + for (size_t i=0; i(&(x.m_array)); + call_replacer(); + current_expr = current_expr_copy_269; + if( x.m_array ) { + visit_expr(*x.m_array); + } + } + }; void pass_replace_array_op(Allocator &al, ASR::TranslationUnit_t &unit, diff --git a/src/libasr/pass/implied_do_loops.cpp b/src/libasr/pass/implied_do_loops.cpp index 6b2c4244149..c4254dc55de 100644 --- a/src/libasr/pass/implied_do_loops.cpp +++ b/src/libasr/pass/implied_do_loops.cpp @@ -56,19 +56,34 @@ class ReplaceArrayConstant: public ASR::BaseExprReplacer { make_ConstantWithKind(make_IntegerConstant_t, make_Integer_t, 1, 4, loc), loc); } int const_elements = 0; + ASR::expr_t* implied_doloop_size_ = nullptr; for( size_t i = 0; i < implied_doloop->n_values; i++ ) { if( ASR::is_a(*implied_doloop->m_values[i]) ) { - ASR::expr_t* implied_doloop_size_ = get_ImpliedDoLoop_size( - ASR::down_cast(implied_doloop->m_values[i])); - implied_doloop_size = builder.ElementalMul(implied_doloop_size_, implied_doloop_size, loc); + if( implied_doloop_size_ == nullptr ) { + implied_doloop_size_ = get_ImpliedDoLoop_size( + ASR::down_cast(implied_doloop->m_values[i])); + } else { + implied_doloop_size_ = builder.ElementalAdd(get_ImpliedDoLoop_size( + ASR::down_cast(implied_doloop->m_values[i])), + implied_doloop_size_, loc); + } } else { const_elements += 1; } } - if( const_elements > 0 ) { - implied_doloop_size = builder.ElementalAdd( - make_ConstantWithKind(make_IntegerConstant_t, make_Integer_t, const_elements, 4, loc), - implied_doloop_size, loc); + if( const_elements > 1 ) { + if( implied_doloop_size_ == nullptr ) { + implied_doloop_size_ = make_ConstantWithKind(make_IntegerConstant_t, + make_Integer_t, const_elements, 4, loc); + } else { + implied_doloop_size_ = builder.ElementalAdd( + make_ConstantWithKind(make_IntegerConstant_t, + make_Integer_t, const_elements, 4, loc), + implied_doloop_size_, loc); + } + } + if( implied_doloop_size_ ) { + implied_doloop_size = builder.ElementalMul(implied_doloop_size_, implied_doloop_size, loc); } return implied_doloop_size; } @@ -165,7 +180,7 @@ class ReplaceArrayConstant: public ASR::BaseExprReplacer { } } ASR::expr_t* constant_size_asr = nullptr; - if (constant_size == 0) { + if (constant_size == 0 && array_size == nullptr) { constant_size = ASRUtils::get_fixed_size_of_array(x->m_type); } if( constant_size > 0 ) { @@ -238,6 +253,13 @@ class ReplaceArrayConstant: public ASR::BaseExprReplacer { arg.loc = result_var->base.loc; arg.m_a = result_var; alloc_args.push_back(al, arg); + Vec to_be_deallocated; + to_be_deallocated.reserve(al, alloc_args.size()); + for( size_t i = 0; i < alloc_args.size(); i++ ) { + to_be_deallocated.push_back(al, alloc_args.p[i].m_a); + } + pass_result.push_back(al, ASRUtils::STMT(ASR::make_ExplicitDeallocate_t( + al, loc, to_be_deallocated.p, to_be_deallocated.size()))); ASR::stmt_t* allocate_stmt = ASRUtils::STMT(ASR::make_Allocate_t( al, loc, alloc_args.p, alloc_args.size(), nullptr, nullptr, nullptr)); pass_result.push_back(al, allocate_stmt); @@ -247,6 +269,13 @@ class ReplaceArrayConstant: public ASR::BaseExprReplacer { arg.loc = result_var_copy->base.loc; arg.m_a = result_var_copy; alloc_args.push_back(al, arg); + Vec to_be_deallocated; + to_be_deallocated.reserve(al, alloc_args.size()); + for( size_t i = 0; i < alloc_args.size(); i++ ) { + to_be_deallocated.push_back(al, alloc_args.p[i].m_a); + } + pass_result.push_back(al, ASRUtils::STMT(ASR::make_ExplicitDeallocate_t( + al, loc, to_be_deallocated.p, to_be_deallocated.size()))); ASR::stmt_t* allocate_stmt = ASRUtils::STMT(ASR::make_Allocate_t( al, loc, alloc_args.p, alloc_args.size(), nullptr, nullptr, nullptr)); pass_result.push_back(al, allocate_stmt); @@ -274,6 +303,13 @@ class ReplaceArrayConstant: public ASR::BaseExprReplacer { } } + void replace_ArrayBroadcast(ASR::ArrayBroadcast_t* x) { + ASR::expr_t** current_expr_copy_161 = current_expr; + current_expr = &(x->m_array); + replace_expr(x->m_array); + current_expr = current_expr_copy_161; + } + }; class ArrayConstantVisitor : public ASR::CallReplacerOnExpressionsVisitor @@ -369,6 +405,148 @@ class ArrayConstantVisitor : public ASR::CallReplacerOnExpressionsVisitor + ASR::asr_t* create_array_constant(const T& x, ASR::expr_t* value) { + // wrap the implied do loop in an array constant + Vec args; + args.reserve(al, 1); + args.push_back(al, value); + + Vec dim; + dim.reserve(al, 1); + + ASR::dimension_t d; + d.loc = value->base.loc; + + ASR::ttype_t *int32_type = ASRUtils::TYPE(ASR::make_Integer_t(al, x.base.base.loc, 4)); + ASR::expr_t* one = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, x.base.base.loc, 1, int32_type)); + + d.m_start = one; + d.m_length = one; + + dim.push_back(al, d); + + ASR::ttype_t* array_type = ASRUtils::TYPE(ASR::make_Array_t(al, value->base.loc, ASRUtils::expr_type(value), dim.p, dim.size(), ASR::array_physical_typeType::FixedSizeArray)); + ASR::asr_t* array_constant = ASR::make_ArrayConstant_t(al, value->base.loc, + args.p, args.n, array_type, ASR::arraystorageType::ColMajor); + return array_constant; + } + + void visit_Print(const ASR::Print_t &x) { + /* + integer :: i + print *, (i, i=1, 10) + + TO + + integer :: i + print *, [(i, i=1, 10)] + */ + ASR::Print_t* print_stmt = const_cast(&x); + for(size_t i = 0; i < x.n_values; i++) { + ASR::expr_t* value = x.m_values[i]; + if (ASR::is_a(*value)) { + ASR::asr_t* array_constant = create_array_constant(x, value); + print_stmt->m_values[i] = ASRUtils::EXPR(array_constant); + + replacer.result_var = value; + resultvar2value[replacer.result_var] = ASRUtils::EXPR(array_constant); + ASR::expr_t** current_expr_copy_9 = current_expr; + current_expr = const_cast(&(print_stmt->m_values[i])); + this->call_replacer(); + current_expr = current_expr_copy_9; + if( !remove_original_statement ) { + this->visit_expr(*print_stmt->m_values[i]); + } + } else { + ASR::expr_t** current_expr_copy_9 = current_expr; + current_expr = const_cast(&(print_stmt->m_values[i])); + this->call_replacer(); + current_expr = current_expr_copy_9; + if( !remove_original_statement ) { + this->visit_expr(*print_stmt->m_values[i]); + } + } + } + } + + void visit_StringFormat(const ASR::StringFormat_t &x) { + /* + integer :: i + write(*, '(i)') (i, i=1, 10) + + TO + + integer :: i + write(*, '(i)') [(i, i=1, 10)] + */ + ASR::StringFormat_t* string_format_stmt = const_cast(&x); + for(size_t i = 0; i < x.n_args; i++) { + ASR::expr_t* value = x.m_args[i]; + if (ASR::is_a(*value)) { + ASR::asr_t* array_constant = create_array_constant(x, value); + string_format_stmt->m_args[i] = ASRUtils::EXPR(array_constant); + + replacer.result_var = value; + resultvar2value[replacer.result_var] = ASRUtils::EXPR(array_constant); + ASR::expr_t** current_expr_copy_9 = current_expr; + current_expr = const_cast(&(string_format_stmt->m_args[i])); + this->call_replacer(); + current_expr = current_expr_copy_9; + if( !remove_original_statement ) { + this->visit_expr(*string_format_stmt->m_args[i]); + } + } else { + ASR::expr_t** current_expr_copy_9 = current_expr; + current_expr = const_cast(&(string_format_stmt->m_args[i])); + this->call_replacer(); + current_expr = current_expr_copy_9; + if( !remove_original_statement ) { + this->visit_expr(*string_format_stmt->m_args[i]); + } + } + } + } + + void visit_FileWrite(const ASR::FileWrite_t &x) { + /* + integer :: i + write(*,*) (i, i=1, 10) + + TO + + integer :: i + write(*,*) [(i, i=1, 10)] + */ + ASR::FileWrite_t* write_stmt = const_cast(&x); + for(size_t i = 0; i < x.n_values; i++) { + ASR::expr_t* value = x.m_values[i]; + if (ASR::is_a(*value)) { + ASR::asr_t* array_constant = create_array_constant(x, value); + + write_stmt->m_values[i] = ASRUtils::EXPR(array_constant); + + replacer.result_var = value; + resultvar2value[replacer.result_var] = ASRUtils::EXPR(array_constant); + ASR::expr_t** current_expr_copy_9 = current_expr; + current_expr = const_cast(&(write_stmt->m_values[i])); + this->call_replacer(); + current_expr = current_expr_copy_9; + if( !remove_original_statement ) { + this->visit_expr(*write_stmt->m_values[i]); + } + } else { + ASR::expr_t** current_expr_copy_9 = current_expr; + current_expr = const_cast(&(write_stmt->m_values[i])); + this->call_replacer(); + current_expr = current_expr_copy_9; + if( !remove_original_statement ) { + this->visit_expr(*write_stmt->m_values[i]); + } + } + } + } + void visit_CPtrToPointer(const ASR::CPtrToPointer_t& x) { if (x.m_shape) { ASR::expr_t** current_expr_copy = current_expr; @@ -380,6 +558,16 @@ class ArrayConstantVisitor : public ASR::CallReplacerOnExpressionsVisitor(&(x.m_array)); + call_replacer(); + current_expr = current_expr_copy_269; + if( x.m_array ) { + visit_expr(*x.m_array); + } + } + }; void pass_replace_implied_do_loops(Allocator &al, diff --git a/src/libasr/pass/inline_function_calls.cpp b/src/libasr/pass/inline_function_calls.cpp index 99c513fd962..6cf0ae3b82e 100644 --- a/src/libasr/pass/inline_function_calls.cpp +++ b/src/libasr/pass/inline_function_calls.cpp @@ -35,65 +35,23 @@ itself. This helps in avoiding function call overhead in the backend code. c = a + 5 */ -class InlineFunctionCallVisitor : public PassUtils::PassVisitor +class FixSymbolsVisitor: public ASR::BaseWalkVisitor { private: - std::string rl_path; - - ASR::expr_t* function_result_var; + SymbolTable*& current_routine_scope; + SymbolTable*& current_scope; - bool from_inline_function_call, inlining_function; - bool fixed_duplicated_expr_stmt; - bool is_fast; + bool& fixed_duplicated_expr_stmt; - // Stores the local variables or/and Block symbol corresponding to the ones - // present in function symbol table. - std::map arg2value; - - std::string current_routine; - - bool inline_external_symbol_calls; - - - ASRUtils::ExprStmtDuplicator node_duplicator; - - SymbolTable* current_routine_scope; - ASRUtils::LabelGenerator* label_generator; - ASR::symbol_t* empty_block; - ASRUtils::ReplaceReturnWithGotoVisitor return_replacer; + std::map& arg2value; public: - bool function_inlined; - - InlineFunctionCallVisitor(Allocator &al_, const std::string& rl_path_, - bool inline_external_symbol_calls_, bool is_fast_) - : PassVisitor(al_, nullptr), - rl_path(rl_path_), function_result_var(nullptr), - from_inline_function_call(false), inlining_function(false), fixed_duplicated_expr_stmt(false), - is_fast(is_fast_), - current_routine(""), inline_external_symbol_calls(inline_external_symbol_calls_), - node_duplicator(al_), current_routine_scope(nullptr), - label_generator(ASRUtils::LabelGenerator::get_instance()), - empty_block(nullptr), return_replacer(al_, 0), - function_inlined(false) - { - pass_result.reserve(al, 1); - } - - void configure_node_duplicator(bool allow_procedure_calls_) { - node_duplicator.allow_procedure_calls = allow_procedure_calls_; - } - - void visit_Function(const ASR::Function_t &x) { - // FIXME: this is a hack, we need to pass in a non-const `x`, - // which requires to generate a TransformVisitor. - ASR::Function_t &xx = const_cast(x); - current_routine = std::string(xx.m_name); - PassUtils::PassVisitor::visit_Function(x); - current_routine.clear(); - } + FixSymbolsVisitor(SymbolTable*& current_routine_scope_, SymbolTable*& current_scope_, + bool& fixed_duplicated_expr_stmt, std::map& arg2value_): + current_routine_scope(current_routine_scope_), current_scope(current_scope_), + fixed_duplicated_expr_stmt(fixed_duplicated_expr_stmt), arg2value(arg2value_) {} // If anything is not local to a function being inlined // then do not inline the function by setting @@ -116,7 +74,7 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitor(x); - ASR::symbol_t *sym = ASRUtils::symbol_get_past_external(x.m_v); + ASR::symbol_t *sym = ASRUtils::symbol_get_past_external(xx.m_v); if (ASR::is_a(*sym)) { replace_symbol(sym, ASR::Variable_t, xx.m_v); } else { @@ -124,9 +82,62 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitor(x); - replace_symbol(x.m_m, ASR::Block_t, xx.m_m); + replace_symbol(xx.m_m, ASR::Block_t, xx.m_m); + } + +}; + +class InlineFunctionCall : public ASR::BaseExprReplacer +{ +private: + + Allocator& al; + std::string rl_path; + + ASR::expr_t* function_result_var; + + bool& from_inline_function_call, inlining_function; + bool fixed_duplicated_expr_stmt; + bool is_fast; + + // Stores the local variables or/and Block symbol corresponding to the ones + // present in function symbol table. + std::map arg2value; + + std::string current_routine; + + bool inline_external_symbol_calls; + + + ASRUtils::ExprStmtDuplicator node_duplicator; + + SymbolTable* current_routine_scope; + ASRUtils::LabelGenerator* label_generator; + ASR::symbol_t* empty_block; + ASRUtils::ReplaceReturnWithGotoVisitor return_replacer; + Vec& pass_result; + +public: + + SymbolTable* current_scope; + FixSymbolsVisitor fix_symbols_visitor; + bool function_inlined; + + InlineFunctionCall(Allocator &al_, const std::string& rl_path_, + bool inline_external_symbol_calls_, bool is_fast_, + Vec& pass_result_, bool& from_inline_function_call_, + std::string& current_routine_): al(al_), rl_path(rl_path_), function_result_var(nullptr), + from_inline_function_call(from_inline_function_call_), inlining_function(false), fixed_duplicated_expr_stmt(false), + is_fast(is_fast_), current_routine(current_routine_), inline_external_symbol_calls(inline_external_symbol_calls_), + node_duplicator(al_), current_routine_scope(nullptr), label_generator(ASRUtils::LabelGenerator::get_instance()), + empty_block(nullptr), return_replacer(al_, 0), pass_result(pass_result_), current_scope(nullptr), + fix_symbols_visitor(current_routine_scope, current_scope, fixed_duplicated_expr_stmt, arg2value), + function_inlined(false) {} + + void configure_node_duplicator(bool allow_procedure_calls_) { + node_duplicator.allow_procedure_calls = allow_procedure_calls_; } void set_empty_block(SymbolTable* scope, const Location& loc) { @@ -147,7 +158,7 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitorerase_symbol("~empty_block"); } - void visit_FunctionCall(const ASR::FunctionCall_t& x) { + void replace_FunctionCall(ASR::FunctionCall_t* x) { // If this node is visited by any other visitor // or it is being visited while inlining another function call // then return. To ensure that only one function call is inlined @@ -157,8 +168,8 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitor(*x.m_name) ) { - ASR::ExternalSymbol_t* called_sym_ext = ASR::down_cast(x.m_name); + if( ASR::is_a(*x->m_name) ) { + ASR::ExternalSymbol_t* called_sym_ext = ASR::down_cast(x->m_name); ASR::symbol_t* f_sym = ASRUtils::symbol_get_past_external(called_sym_ext->m_external); ASR::Function_t* f = ASR::down_cast(f_sym); @@ -167,12 +178,11 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitorm_name; // TODO: Handle later // ASR::symbol_t* called_sym_original = x.m_original_name; - ASR::FunctionCall_t& xx = const_cast(x); std::string called_sym_name = std::string(called_sym_ext->m_name); std::string new_sym_name_str = current_scope->get_unique_name(called_sym_name, false); char* new_sym_name = s2c(al, new_sym_name_str); @@ -185,15 +195,25 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitorm_name, ASR::accessType::Private)); current_scope->add_symbol(new_sym_name_str, new_sym); } - xx.m_name = current_scope->get_symbol(new_sym_name_str); + x->m_name = current_scope->get_symbol(new_sym_name_str); } - for( size_t i = 0; i < x.n_args; i++ ) { - visit_expr(*x.m_args[i].m_value); + for( size_t i = 0; i < x->n_args; i++ ) { + fix_symbols_visitor.visit_expr(*x->m_args[i].m_value); } return ; } + // Avoid inlining if function call accepts a callback argument + for( size_t i = 0; i < x->n_args; i++ ) { + if( x->m_args[i].m_value && + ASR::is_a( + *ASRUtils::type_get_past_pointer( + ASRUtils::expr_type(x->m_args[i].m_value))) ) { + return ; + } + } + // Clear up any local variables present in arg2value map // due to inlining other function calls arg2value.clear(); @@ -205,11 +225,11 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitorm_name; if( !ASR::is_a(*routine) ) { if( ASR::is_a(*routine) && inline_external_symbol_calls) { - routine = ASRUtils::symbol_get_past_external(x.m_name); + routine = ASRUtils::symbol_get_past_external(x->m_name); if( !ASR::is_a(*routine) ) { return ; } @@ -242,7 +262,7 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitorn_args ) { func_margs_i = func->m_args[i]; - x_m_args_i = x.m_args[i].m_value; + x_m_args_i = x->m_args[i].m_value; } else { func_margs_i = func->m_return_var; x_m_args_i = nullptr; @@ -253,7 +273,9 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitor(func_margs_i); // TODO: Expand to other symbol types, Function, Subroutine, ExternalSymbol - if( !ASR::is_a(*(arg_var->m_v)) ) { + if( !ASR::is_a(*(arg_var->m_v)) || + ASRUtils::is_character(*ASRUtils::symbol_type(arg_var->m_v)) || + ASRUtils::is_array(ASRUtils::symbol_type(arg_var->m_v)) ) { arg2value.clear(); return ; } @@ -288,7 +310,11 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitorbase.base.loc); continue; } - if( !ASR::is_a(*itr.second) ) { + if( !ASR::is_a(*itr.second) || + ASRUtils::is_character(*ASRUtils::symbol_type(itr.second)) || + ASRUtils::is_array(ASRUtils::symbol_type(itr.second)) || + ASR::is_a(*ASRUtils::symbol_type(itr.second)) || + ASR::is_a(*ASRUtils::symbol_type(itr.second)) ) { arg2value.clear(); return ; } @@ -337,7 +363,7 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitorn_body && success; i++ ) { fixed_duplicated_expr_stmt = true; - visit_stmt(*func_copy[i]); + fix_symbols_visitor.visit_stmt(*func_copy[i]); success = success && fixed_duplicated_expr_stmt; } if( success ) { set_empty_block(current_scope, func->base.base.loc); uint64_t block_call_label = label_generator->get_unique_label(); - ASR::stmt_t* block_call = ASRUtils::STMT(ASR::make_BlockCall_t(al, x.base.base.loc, - block_call_label, empty_block)); + ASR::stmt_t* block_call = ASRUtils::STMT(ASR::make_BlockCall_t( + al, x->base.base.loc, block_call_label, empty_block)); label_generator->add_node_with_unique_label((ASR::asr_t*) block_call, block_call_label); return_replacer.set_goto_label(block_call_label); @@ -417,63 +443,94 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitor +{ +private: + + Allocator& al; - void visit_UnsignedIntegerBinOp(const ASR::UnsignedIntegerBinOp_t& x) { - handle_BinOp(x); + bool from_inline_function_call; + std::string current_routine; + Vec* parent_body; + Vec pass_result; + + InlineFunctionCall replacer; + +public: + + bool function_inlined; + + InlineFunctionCallVisitor(Allocator &al_, const std::string& rl_path_, + bool inline_external_symbol_calls_, bool is_fast_): + al(al_), current_routine(""), parent_body(nullptr), + replacer(al_, rl_path_, inline_external_symbol_calls_, is_fast_, + pass_result, from_inline_function_call, current_routine) { + pass_result.reserve(al, 1); } - void visit_RealBinOp(const ASR::RealBinOp_t& x) { - handle_BinOp(x); + void configure_node_duplicator(bool allow_procedure_calls_) { + replacer.configure_node_duplicator(allow_procedure_calls_); } - void visit_ComplexBinOp(const ASR::ComplexBinOp_t& x) { - handle_BinOp(x); + void call_replacer() { + replacer.current_expr = current_expr; + replacer.current_scope = current_scope; + replacer.replace_expr(*current_expr); } - void visit_LogicalBinOp(const ASR::LogicalBinOp_t& x) { - handle_BinOp(x); + void visit_Function(const ASR::Function_t &x) { + // FIXME: this is a hack, we need to pass in a non-const `x`, + // which requires to generate a TransformVisitor. + ASR::Function_t &xx = const_cast(x); + current_routine = std::string(xx.m_name); + ASR::CallReplacerOnExpressionsVisitor::visit_Function(x); + current_routine.clear(); } - template - void handle_BinOp(const T& x) { - T& xx = const_cast(x); + void visit_Assignment(const ASR::Assignment_t& x) { from_inline_function_call = true; - function_result_var = nullptr; - visit_expr(*x.m_left); - if( function_result_var ) { - xx.m_left = function_result_var; - } - function_result_var = nullptr; - visit_expr(*x.m_right); - if( function_result_var ) { - xx.m_right = function_result_var; - } - function_result_var = nullptr; + ASR::CallReplacerOnExpressionsVisitor::visit_Assignment(x); from_inline_function_call = false; } - void visit_Assignment(const ASR::Assignment_t& x) { - from_inline_function_call = true; - retain_original_stmt = true; - ASR::Assignment_t& xx = const_cast(x); - function_result_var = nullptr; - visit_expr(*x.m_target); - function_result_var = nullptr; - visit_expr(*x.m_value); - if( function_result_var ) { - xx.m_value = function_result_var; + void transform_stmts(ASR::stmt_t **&m_body, size_t &n_body) { + Vec body; + body.reserve(al, n_body); + if( parent_body ) { + for (size_t j=0; j < pass_result.size(); j++) { + parent_body->push_back(al, pass_result[j]); + } } - function_result_var = nullptr; - from_inline_function_call = false; + for (size_t i=0; i* parent_body_copy = parent_body; + parent_body = &body; + visit_stmt(*m_body[i]); + parent_body = parent_body_copy; + for (size_t j=0; j < pass_result.size(); j++) { + body.push_back(al, pass_result[j]); + } + body.push_back(al, m_body[i]); + } + m_body = body.p; + n_body = body.size(); + pass_result.n = 0; + } + + void visit_Character(const ASR::Character_t& /*x*/) { + } }; diff --git a/src/libasr/pass/insert_deallocate.cpp b/src/libasr/pass/insert_deallocate.cpp new file mode 100644 index 00000000000..c3ddb204dd1 --- /dev/null +++ b/src/libasr/pass/insert_deallocate.cpp @@ -0,0 +1,67 @@ +#include +#include +#include +#include +#include +#include + +#include +#include + + +namespace LCompilers { + +class InsertDeallocate: public ASR::CallReplacerOnExpressionsVisitor +{ + private: + + Allocator& al; + + public: + + InsertDeallocate(Allocator& al_) : al(al_) {} + + template + void visit_Symbol(const T& x) { + Vec to_be_deallocated; + to_be_deallocated.reserve(al, 1); + for( auto& itr: x.m_symtab->get_scope() ) { + if( ASR::is_a(*itr.second) && + ASR::is_a(*ASRUtils::symbol_type(itr.second)) && + ASRUtils::is_array(ASRUtils::symbol_type(itr.second)) && + ASRUtils::symbol_intent(itr.second) == ASRUtils::intent_local ) { + to_be_deallocated.push_back(al, ASRUtils::EXPR( + ASR::make_Var_t(al, x.base.base.loc, itr.second))); + } + } + if( to_be_deallocated.size() > 0 ) { + T& xx = const_cast(x); + Vec body; + body.from_pointer_n_copy(al, xx.m_body, xx.n_body); + body.push_back(al, ASRUtils::STMT(ASR::make_ImplicitDeallocate_t( + al, x.base.base.loc, to_be_deallocated.p, to_be_deallocated.size()))); + xx.m_body = body.p; + xx.n_body = body.size(); + } + } + + void visit_Function(const ASR::Function_t& x) { + visit_Symbol(x); + ASR::CallReplacerOnExpressionsVisitor::visit_Function(x); + } + + void visit_Program(const ASR::Program_t& x) { + visit_Symbol(x); + ASR::CallReplacerOnExpressionsVisitor::visit_Program(x); + } + +}; + +void pass_insert_deallocate(Allocator &al, ASR::TranslationUnit_t &unit, + const PassOptions &/*pass_options*/) { + InsertDeallocate v(al); + v.visit_TranslationUnit(unit); +} + + +} // namespace LCompilers diff --git a/src/libasr/pass/insert_deallocate.h b/src/libasr/pass/insert_deallocate.h new file mode 100644 index 00000000000..206b8bdfe8a --- /dev/null +++ b/src/libasr/pass/insert_deallocate.h @@ -0,0 +1,14 @@ +#ifndef LIBASR_PASS_INSERT_DEALLOCATE_H +#define LIBASR_PASS_INSERT_DEALLOCATE_H + +#include +#include + +namespace LCompilers { + + void pass_insert_deallocate(Allocator &al, ASR::TranslationUnit_t &unit, + const PassOptions &pass_options); + +} // namespace LCompilers + +#endif // LIBASR_PASS_INSERT_DEALLOCATE_H diff --git a/src/libasr/pass/instantiate_template.cpp b/src/libasr/pass/instantiate_template.cpp index 14144170c31..65b8145efcc 100644 --- a/src/libasr/pass/instantiate_template.cpp +++ b/src/libasr/pass/instantiate_template.cpp @@ -7,6 +7,150 @@ namespace LCompilers { +class SymbolRenamer : public ASR::BaseExprStmtDuplicator +{ +public: + SymbolTable* current_scope; + std::map &type_subs; + std::string new_sym_name; + + SymbolRenamer(Allocator& al, std::map& type_subs, + SymbolTable* current_scope, std::string new_sym_name): + BaseExprStmtDuplicator(al), + current_scope{current_scope}, + type_subs{type_subs}, + new_sym_name{new_sym_name} + {} + + ASR::symbol_t* rename_symbol(ASR::symbol_t *x) { + switch (x->type) { + case (ASR::symbolType::Variable): { + ASR::Variable_t *v = ASR::down_cast(x); + return rename_Variable(v); + } + case (ASR::symbolType::Function): { + if (current_scope->get_symbol(new_sym_name)) { + return current_scope->get_symbol(new_sym_name); + } + ASR::Function_t *f = ASR::down_cast(x); + return rename_Function(f); + } + default: { + std::string sym_name = ASRUtils::symbol_name(x); + throw new LCompilersException("Symbol renaming not supported " + " for " + sym_name); + } + } + } + + ASR::symbol_t* rename_Variable(ASR::Variable_t *x) { + ASR::ttype_t *t = x->m_type; + ASR::dimension_t* tp_m_dims = nullptr; + int tp_n_dims = ASRUtils::extract_dimensions_from_ttype(t, tp_m_dims); + + if (ASR::is_a(*t)) { + ASR::TypeParameter_t *tp = ASR::down_cast(t); + if (type_subs.find(tp->m_param) != type_subs.end()) { + t = ASRUtils::make_Array_t_util(al, tp->base.base.loc, + ASRUtils::duplicate_type(al, type_subs[tp->m_param]), + tp_m_dims, tp_n_dims); + } else { + t = ASRUtils::make_Array_t_util(al, tp->base.base.loc, ASRUtils::TYPE( + ASR::make_TypeParameter_t(al, tp->base.base.loc, + s2c(al, new_sym_name))), tp_m_dims, tp_n_dims); + type_subs[tp->m_param] = t; + } + } + + if (current_scope->get_symbol(new_sym_name)) { + return current_scope->get_symbol(new_sym_name); + } + + ASR::symbol_t* new_v = ASR::down_cast(ASR::make_Variable_t( + al, x->base.base.loc, + current_scope, s2c(al, new_sym_name), x->m_dependencies, + x->n_dependencies, x->m_intent, x->m_symbolic_value, + x->m_value, x->m_storage, t, x->m_type_declaration, + x->m_abi, x->m_access, x->m_presence, x->m_value_attr)); + + current_scope->add_symbol(new_sym_name, new_v); + + return new_v; + } + + ASR::symbol_t* rename_Function(ASR::Function_t *x) { + ASR::FunctionType_t* ftype = ASR::down_cast(x->m_function_signature); + + SymbolTable* parent_scope = current_scope; + current_scope = al.make_new(parent_scope); + + Vec args; + args.reserve(al, x->n_args); + for (size_t i=0; in_args; i++) { + ASR::expr_t *new_arg = duplicate_expr(x->m_args[i]); + args.push_back(al, new_arg); + } + + ASR::expr_t *new_return_var_ref = nullptr; + if (x->m_return_var != nullptr) { + new_return_var_ref = duplicate_expr(x->m_return_var); + } + + ASR::symbol_t *new_f = ASR::down_cast(ASRUtils::make_Function_t_util( + al, x->base.base.loc, current_scope, s2c(al, new_sym_name), x->m_dependencies, + x->n_dependencies, args.p, args.size(), nullptr, 0, new_return_var_ref, ftype->m_abi, + x->m_access, ftype->m_deftype, ftype->m_bindc_name, ftype->m_elemental, + ftype->m_pure, ftype->m_module, ftype->m_inline, ftype->m_static, ftype->m_restrictions, + ftype->n_restrictions, ftype->m_is_restriction, x->m_deterministic, x->m_side_effect_free)); + + parent_scope->add_symbol(new_sym_name, new_f); + current_scope = parent_scope; + + return new_f; + } + + ASR::asr_t* duplicate_Var(ASR::Var_t *x) { + std::string sym_name = ASRUtils::symbol_name(x->m_v); + ASR::symbol_t* sym = duplicate_symbol(x->m_v); + return ASR::make_Var_t(al, x->base.base.loc, sym); + } + + ASR::symbol_t* duplicate_symbol(ASR::symbol_t *x) { + ASR::symbol_t* new_symbol = nullptr; + switch (x->type) { + case ASR::symbolType::Variable: { + new_symbol = duplicate_Variable(ASR::down_cast(x)); + break; + } + default: { + throw LCompilersException("Unsupported symbol for symbol renaming"); + } + } + return new_symbol; + } + + ASR::symbol_t* duplicate_Variable(ASR::Variable_t *x) { + ASR::ttype_t *t = x->m_type; + + if (ASR::is_a(*t)) { + ASR::TypeParameter_t *tp = ASR::down_cast(t); + LCOMPILERS_ASSERT(type_subs.find(tp->m_param) != type_subs.end()); + t = ASRUtils::duplicate_type(al, type_subs[tp->m_param]); + } + + ASR::symbol_t* new_v = ASR::down_cast(ASR::make_Variable_t( + al, x->base.base.loc, current_scope, x->m_name, x->m_dependencies, + x->n_dependencies, x->m_intent, x->m_symbolic_value, + x->m_value, x->m_storage, t, x->m_type_declaration, + x->m_abi, x->m_access, x->m_presence, x->m_value_attr)); + + current_scope->add_symbol(x->m_name, new_v); + + return new_v; + } + +}; + class SymbolInstantiator : public ASR::BaseExprStmtDuplicator { public: @@ -44,8 +188,8 @@ class SymbolInstantiator : public ASR::BaseExprStmtDuplicatorbase.loc); + throw new LCompilersException("Instantiation of " + sym_name + + " symbol is not supported"); }; } } @@ -195,15 +339,6 @@ class SymbolInstantiator : public ASR::BaseExprStmtDuplicatoradd_symbol(new_sym_name, t); context_map[x->m_name] = new_sym_name; - /* - for (auto const &sym_pair: x->m_symtab->get_scope()) { - ASR::symbol_t *sym = sym_pair.second; - if (ASR::is_a(*sym)) { - ASR::symbol_t *new_sym = duplicate_ClassProcedure(sym); - current_scope->add_symbol(ASRUtils::symbol_name(new_sym), new_sym); - } - } - */ for (auto const &sym_pair: x->m_symtab->get_scope()) { if (ASR::is_a(*sym_pair.second)) { duplicate_symbol(sym_pair.second); @@ -324,7 +459,8 @@ class SymbolInstantiator : public ASR::BaseExprStmtDuplicatorm_type); return ASRUtils::make_ArrayItem_t_util(al, x->base.base.loc, m_v, args.p, x->n_args, - ASRUtils::type_get_past_allocatable(type), x->m_storage_format, m_value); + ASRUtils::type_get_past_pointer( + ASRUtils::type_get_past_allocatable(type)), x->m_storage_format, m_value); } ASR::asr_t* duplicate_ArrayConstant(ASR::ArrayConstant_t *x) { @@ -412,11 +548,11 @@ class SymbolInstantiator : public ASR::BaseExprStmtDuplicatorresolve_symbol(context_map[call_name]); } else if (ASRUtils::is_generic_function(name)) { ASR::symbol_t *search_sym = current_scope->resolve_symbol(call_name); - if (search_sym != nullptr) { + if (search_sym != nullptr && ASR::is_a(*search_sym)) { name = search_sym; } else { ASR::symbol_t* name2 = ASRUtils::symbol_get_past_external(name); - std::string nested_func_name = current_scope->get_unique_name("__asr_" + call_name, false); + std::string nested_func_name = func_scope->get_unique_name("__asr_" + call_name, false); SymbolInstantiator nested(al, context_map, type_subs, symbol_subs, func_scope, template_scope, nested_func_name); name = nested.instantiate_symbol(name2); name = nested.instantiate_body(ASR::down_cast(name), ASR::down_cast(name2)); @@ -428,7 +564,7 @@ class SymbolInstantiator : public ASR::BaseExprStmtDuplicatorget_counter() != current_scope->get_counter()) { + if (ASRUtils::symbol_parent_symtab(name)->get_counter() != current_scope->get_counter() && !ASR::is_a(*name)) { ADD_ASR_DEPENDENCIES(current_scope, name, dependencies); } return ASRUtils::make_FunctionCall_t_util(al, x->base.base.loc, name, x->m_original_name, @@ -472,7 +608,7 @@ class SymbolInstantiator : public ASR::BaseExprStmtDuplicatorget_counter() != current_scope->get_counter()) { + if (ASRUtils::symbol_parent_symtab(name)->get_counter() != current_scope->get_counter() && !ASR::is_a(*name)) { ADD_ASR_DEPENDENCIES(current_scope, name, dependencies); } return ASRUtils::make_SubroutineCall_t_util(al, x->base.base.loc, name /* change this */, @@ -488,6 +624,14 @@ class SymbolInstantiator : public ASR::BaseExprStmtDuplicatorm_arg); + ASR::ttype_t *ttype = substitute_type(x->m_type); + ASR::expr_t *value = duplicate_expr(x->m_value); + return ASR::make_ArrayPhysicalCast_t(al, x->base.base.loc, + arg, x->m_old, x->m_new, ttype, value); + } + ASR::ttype_t* substitute_type(ASR::ttype_t *ttype) { switch (ttype->type) { case (ASR::ttypeType::TypeParameter) : { @@ -582,7 +726,7 @@ class SymbolInstantiator : public ASR::BaseExprStmtDuplicator& context_map, std::map type_subs, std::map symbol_subs, @@ -594,7 +738,7 @@ ASR::symbol_t* pass_instantiate_symbol(Allocator &al, return t.instantiate_symbol(sym2); } -ASR::symbol_t* pass_instantiate_function_body(Allocator &al, +ASR::symbol_t* instantiate_function_body(Allocator &al, std::map& context_map, std::map type_subs, std::map symbol_subs, @@ -605,9 +749,60 @@ ASR::symbol_t* pass_instantiate_function_body(Allocator &al, return t.instantiate_body(new_f, f); } -void check_restriction(std::map type_subs, +ASR::symbol_t* rename_symbol(Allocator &al, + std::map &type_subs, + SymbolTable *current_scope, + std::string new_sym_name, ASR::symbol_t *sym) { + SymbolRenamer t(al, type_subs, current_scope, new_sym_name); + return t.rename_symbol(sym); +} + +bool check_restriction(std::map type_subs, + std::map &symbol_subs, + ASR::Function_t *f, ASR::symbol_t *sym_arg) { + std::string f_name = f->m_name; + ASR::Function_t *arg = ASR::down_cast(ASRUtils::symbol_get_past_external(sym_arg)); + std::string arg_name = arg->m_name; + if (f->n_args != arg->n_args) { + return false; + } + for (size_t i = 0; i < f->n_args; i++) { + ASR::ttype_t *f_param = ASRUtils::expr_type(f->m_args[i]); + ASR::ttype_t *arg_param = ASRUtils::expr_type(arg->m_args[i]); + if (ASR::is_a(*f_param)) { + ASR::TypeParameter_t *f_tp + = ASR::down_cast(f_param); + if (!ASRUtils::check_equal_type(type_subs[f_tp->m_param], + arg_param)) { + return false; + } + } + } + if (f->m_return_var) { + if (!arg->m_return_var) { + return false; + } + ASR::ttype_t *f_ret = ASRUtils::expr_type(f->m_return_var); + ASR::ttype_t *arg_ret = ASRUtils::expr_type(arg->m_return_var); + if (ASR::is_a(*f_ret)) { + ASR::TypeParameter_t *return_tp + = ASR::down_cast(f_ret); + if (!ASRUtils::check_equal_type(type_subs[return_tp->m_param], arg_ret)) { + return false; + } + } + } else { + if (arg->m_return_var) { + return false; + } + } + symbol_subs[f_name] = sym_arg; + return true; +} + +void report_check_restriction(std::map type_subs, std::map &symbol_subs, - ASR::Function_t *f, ASR::symbol_t *sym_arg, const Location& loc, + ASR::Function_t *f, ASR::symbol_t *sym_arg, const Location &loc, diag::Diagnostics &diagnostics) { std::string f_name = f->m_name; ASR::Function_t *arg = ASR::down_cast(ASRUtils::symbol_get_past_external(sym_arg)); @@ -649,8 +844,6 @@ void check_restriction(std::map type_subs, {f->m_args[i]->base.loc}), diag::Label("Function's parameter " + avar + " of type " + atype, {arg->m_args[i]->base.loc}) - - } )); throw SemanticAbort(); diff --git a/src/libasr/pass/instantiate_template.h b/src/libasr/pass/instantiate_template.h index a7ba880ece0..253adc7de30 100644 --- a/src/libasr/pass/instantiate_template.h +++ b/src/libasr/pass/instantiate_template.h @@ -11,23 +11,32 @@ namespace LCompilers { * contain any type parameters and restrictions. No type checking * is executed here */ - ASR::symbol_t* pass_instantiate_symbol(Allocator &al, - std::map& context_map, + ASR::symbol_t* instantiate_symbol(Allocator &al, + std::map &context_map, std::map type_subs, std::map symbol_subs, SymbolTable *current_scope, SymbolTable *template_scope, std::string new_sym_name, ASR::symbol_t *sym); - ASR::symbol_t* pass_instantiate_function_body(Allocator &al, - std::map& context_map, + ASR::symbol_t* instantiate_function_body(Allocator &al, + std::map &context_map, std::map type_subs, std::map symbol_subs, SymbolTable *current_scope, SymbolTable *template_scope, ASR::Function_t *new_f, ASR::Function_t *f); - void check_restriction(std::map type_subs, + ASR::symbol_t* rename_symbol(Allocator &al, + std::map &type_subs, + SymbolTable *current_scope, + std::string new_sym_name, ASR::symbol_t *sym); + + bool check_restriction(std::map type_subs, + std::map &symbol_subs, + ASR::Function_t *f, ASR::symbol_t *sym_arg); + + void report_check_restriction(std::map type_subs, std::map &symbol_subs, - ASR::Function_t *f, ASR::symbol_t *sym_arg, const Location& loc, + ASR::Function_t *f, ASR::symbol_t *sym_arg, const Location &loc, diag::Diagnostics &diagnostics); } // namespace LCompilers diff --git a/src/libasr/pass/intrinsic_array_function_registry.h b/src/libasr/pass/intrinsic_array_function_registry.h index b9a175d2931..3c3bca0a6b1 100644 --- a/src/libasr/pass/intrinsic_array_function_registry.h +++ b/src/libasr/pass/intrinsic_array_function_registry.h @@ -68,7 +68,7 @@ namespace ArrIntrinsic { static inline void verify_array_int_real_cmplx(ASR::expr_t* array, ASR::ttype_t* return_type, const Location& loc, diag::Diagnostics& diagnostics, ASRUtils::IntrinsicArrayFunctions intrinsic_func_id) { - std::string intrinsic_func_name = ASRUtils::get_intrinsic_name(static_cast(intrinsic_func_id)); + std::string intrinsic_func_name = ASRUtils::get_array_intrinsic_name(static_cast(intrinsic_func_id)); ASR::ttype_t* array_type = ASRUtils::expr_type(array); ASRUtils::require_impl(ASRUtils::is_integer(*array_type) || ASRUtils::is_real(*array_type) || @@ -89,7 +89,7 @@ static inline void verify_array_int_real_cmplx(ASR::expr_t* array, ASR::ttype_t* static inline void verify_array_int_real(ASR::expr_t* array, ASR::ttype_t* return_type, const Location& loc, diag::Diagnostics& diagnostics, ASRUtils::IntrinsicArrayFunctions intrinsic_func_id) { - std::string intrinsic_func_name = ASRUtils::get_intrinsic_name(static_cast(intrinsic_func_id)); + std::string intrinsic_func_name = ASRUtils::get_array_intrinsic_name(static_cast(intrinsic_func_id)); ASR::ttype_t* array_type = ASRUtils::expr_type(array); ASRUtils::require_impl(ASRUtils::is_integer(*array_type) || ASRUtils::is_real(*array_type), @@ -109,7 +109,7 @@ static inline void verify_array_int_real(ASR::expr_t* array, ASR::ttype_t* retur static inline void verify_array_dim(ASR::expr_t* array, ASR::expr_t* dim, ASR::ttype_t* return_type, const Location& loc, diag::Diagnostics& diagnostics, ASRUtils::IntrinsicArrayFunctions intrinsic_func_id) { - std::string intrinsic_func_name = ASRUtils::get_intrinsic_name(static_cast(intrinsic_func_id)); + std::string intrinsic_func_name = ASRUtils::get_array_intrinsic_name(static_cast(intrinsic_func_id)); ASR::ttype_t* array_type = ASRUtils::expr_type(array); ASRUtils::require_impl(ASRUtils::is_integer(*array_type) || ASRUtils::is_real(*array_type) || @@ -135,7 +135,7 @@ static inline void verify_array_dim(ASR::expr_t* array, ASR::expr_t* dim, static inline void verify_args(const ASR::IntrinsicArrayFunction_t& x, diag::Diagnostics& diagnostics, ASRUtils::IntrinsicArrayFunctions intrinsic_func_id, verify_array_func verify_array) { - std::string intrinsic_func_name = ASRUtils::get_intrinsic_name(static_cast(intrinsic_func_id)); + std::string intrinsic_func_name = ASRUtils::get_array_intrinsic_name(static_cast(intrinsic_func_id)); ASRUtils::require_impl(x.n_args >= 1, intrinsic_func_name + " intrinsic must accept at least one argument", x.base.base.loc, diagnostics); ASRUtils::require_impl(x.m_args[0] != nullptr, "Array argument to " + intrinsic_func_name + " intrinsic cannot be nullptr", @@ -197,7 +197,7 @@ static inline ASR::asr_t* create_ArrIntrinsic( Allocator& al, const Location& loc, Vec& args, const std::function err, ASRUtils::IntrinsicArrayFunctions intrinsic_func_id) { - std::string intrinsic_func_name = ASRUtils::get_intrinsic_name(static_cast(intrinsic_func_id)); + std::string intrinsic_func_name = ASRUtils::get_array_intrinsic_name(static_cast(intrinsic_func_id)); int64_t id_array = 0, id_array_dim = 1, id_array_mask = 2; int64_t id_array_dim_mask = 3; int64_t overload_id = id_array; @@ -418,7 +418,7 @@ static inline ASR::expr_t* instantiate_ArrIntrinsic(Allocator &al, int64_t overload_id, ASRUtils::IntrinsicArrayFunctions intrinsic_func_id, get_initial_value_func get_initial_value, elemental_operation_func elemental_operation) { - std::string intrinsic_func_name = ASRUtils::get_intrinsic_name(static_cast(intrinsic_func_id)); + std::string intrinsic_func_name = ASRUtils::get_array_intrinsic_name(static_cast(intrinsic_func_id)); ASRBuilder builder(al, loc); ASRBuilder& b = builder; int64_t id_array = 0, id_array_dim = 1, id_array_mask = 2; @@ -534,7 +534,7 @@ static inline ASR::expr_t* instantiate_ArrIntrinsic(Allocator &al, static inline void verify_MaxMinLoc_args(const ASR::IntrinsicArrayFunction_t& x, diag::Diagnostics& diagnostics) { - std::string intrinsic_name = get_intrinsic_name( + std::string intrinsic_name = get_array_intrinsic_name( static_cast(x.m_arr_intrinsic_id)); require_impl(x.n_args >= 1, "`"+ intrinsic_name +"` intrinsic " "must accept at least one argument", x.base.base.loc, diagnostics); @@ -579,7 +579,7 @@ static inline ASR::expr_t *eval_MaxMinLoc(Allocator &al, const Location &loc, static inline ASR::asr_t* create_MaxMinLoc(Allocator& al, const Location& loc, Vec& args, int intrinsic_id, const std::function err) { - std::string intrinsic_name = get_intrinsic_name(static_cast(intrinsic_id)); + std::string intrinsic_name = get_array_intrinsic_name(static_cast(intrinsic_id)); ASR::ttype_t *array_type = expr_type(args[0]); if ( !is_array(array_type) ) { err("`array` argument of `"+ intrinsic_name +"` must be an array", loc); @@ -647,7 +647,7 @@ static inline ASR::expr_t *instantiate_MaxMinLoc(Allocator &al, const Location &loc, SymbolTable *scope, int intrinsic_id, Vec& arg_types, ASR::ttype_t *return_type, Vec& m_args, int64_t /*overload_id*/) { - std::string intrinsic_name = get_intrinsic_name(static_cast(intrinsic_id)); + std::string intrinsic_name = get_array_intrinsic_name(static_cast(intrinsic_id)); declare_basic_variables("_lcompilers_" + intrinsic_name) /* * max_index = 1; min_index @@ -813,7 +813,8 @@ namespace Shape { const Location &loc, SymbolTable *scope, Vec& arg_types, ASR::ttype_t *return_type, Vec& new_args, int64_t) { declare_basic_variables("_lcompilers_shape"); - fill_func_arg("source", arg_types[0]); + fill_func_arg("source", ASRUtils::duplicate_type_with_empty_dims(al, + arg_types[0])); auto result = declare(fn_name, return_type, ReturnVar); int iter = extract_n_dims_from_ttype(arg_types[0]) + 1; auto i = declare("i", int32, Local); @@ -1377,7 +1378,7 @@ namespace Merge { auto mask_arg = declare("mask", mask_type, In); args.push_back(al, mask_arg); // TODO: In case of Character type, set len of ReturnVar to len(tsource) expression - auto result = declare("merge", tsource_type, ReturnVar); + auto result = declare("merge", type_get_past_allocatable(return_type), ReturnVar); { Vec if_body; if_body.reserve(al, 1); diff --git a/src/libasr/pass/intrinsic_function.cpp b/src/libasr/pass/intrinsic_function.cpp index 89f3572bd1f..64f8f37c406 100644 --- a/src/libasr/pass/intrinsic_function.cpp +++ b/src/libasr/pass/intrinsic_function.cpp @@ -212,10 +212,17 @@ class ReplaceFunctionCallReturningArray: public ASR::BaseExprReplacer alloc_args; alloc_args.reserve(al, 1); alloc_args.push_back(al, alloc_arg); + Vec to_be_deallocated; + to_be_deallocated.reserve(al, alloc_args.size()); + for( size_t i = 0; i < alloc_args.size(); i++ ) { + to_be_deallocated.push_back(al, alloc_args.p[i].m_a); + } ASR::stmt_t* allocate_stmt = ASRUtils::STMT(ASR::make_Allocate_t( al, loc_, alloc_args.p, alloc_args.size(), nullptr, nullptr, nullptr)); Vec if_body; - if_body.reserve(al, 1); + if_body.reserve(al, 2); + if_body.push_back(al, ASRUtils::STMT(ASR::make_ExplicitDeallocate_t( + al, loc, to_be_deallocated.p, to_be_deallocated.size()))); if_body.push_back(al, allocate_stmt); ASR::stmt_t* if_ = ASRUtils::STMT(ASR::make_If_t(al, loc_, test_expr, if_body.p, if_body.size(), else_, else_n)); @@ -345,6 +352,7 @@ class ReplaceFunctionCallReturningArrayVisitor : public ASR::CallReplacerOnExpre parent_body->push_back(al, pass_result[j]); } } + for (size_t i=0; itype) { case ASR::ttypeType::Integer : { return EXPR(ASR::make_IntegerBinOp_t(al, loc, left, @@ -368,6 +394,7 @@ class ASRBuilder { ASR::expr_t *Mul(ASR::expr_t *left, ASR::expr_t *right) { LCOMPILERS_ASSERT(check_equal_type(expr_type(left), expr_type(right))); ASR::ttype_t *type = expr_type(left); + ASRUtils::make_ArrayBroadcast_t_util(al, loc, left, right); switch (type->type) { case ASR::ttypeType::Integer : { return EXPR(ASR::make_IntegerBinOp_t(al, loc, left, @@ -1738,6 +1765,136 @@ namespace Aint { } // namespace Aint +namespace Sqrt { + + static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, + diag::Diagnostics& diagnostics) { + ASRUtils::require_impl(x.n_args == 1, + "ASR Verify: Call `sqrt` must have exactly one argument", + x.base.base.loc, diagnostics); + ASR::ttype_t *type = ASRUtils::expr_type(x.m_args[0]); + ASRUtils::require_impl(ASRUtils::is_real(*type) || ASRUtils::is_complex(*type), + "ASR Verify: Arguments to `sqrt` must be of real or complex type", + x.base.base.loc, diagnostics); + } + + static ASR::expr_t *eval_Sqrt(Allocator &al, const Location &loc, + ASR::ttype_t* arg_type, Vec &args) { + if (is_real(*arg_type)) { + double val = ASR::down_cast(expr_value(args[0]))->m_r; + return f(std::sqrt(val), arg_type); + } else { + std::complex crv; + if( ASRUtils::extract_value(args[0], crv) ) { + std::complex val = std::sqrt(crv); + return ASRUtils::EXPR(ASR::make_ComplexConstant_t( + al, loc, val.real(), val.imag(), arg_type)); + } else { + return nullptr; + } + } + } + + static inline ASR::asr_t* create_Sqrt(Allocator& al, const Location& loc, + Vec& args, + const std::function err) { + ASR::ttype_t* return_type = expr_type(args[0]); + if ( args.n != 1 ) { + err("Intrinsic `sqrt` accepts exactly one argument", loc); + } else if ( !(is_real(*return_type) || is_complex(*return_type)) ) { + err("Argument of the `sqrt` must be Real or Complex", loc); + } + ASR::expr_t *m_value = nullptr; + if (all_args_evaluated(args)) { + m_value = eval_Sqrt(al, loc, return_type, args); + } + return ASR::make_IntrinsicScalarFunction_t(al, loc, + static_cast(IntrinsicScalarFunctions::Sqrt), + args.p, args.n, 0, return_type, m_value); + } + + static inline ASR::expr_t* instantiate_Sqrt(Allocator &al, const Location &loc, + SymbolTable *scope, Vec& arg_types, ASR::ttype_t *return_type, + Vec& new_args, int64_t overload_id) { + ASR::ttype_t* arg_type = arg_types[0]; + if (is_real(*arg_type)) { + return EXPR(ASR::make_IntrinsicFunctionSqrt_t(al, loc, + new_args[0].m_value, return_type, nullptr)); + } else { + return UnaryIntrinsicFunction::instantiate_functions(al, loc, scope, + "sqrt", arg_type, return_type, new_args, overload_id); + } + } + +} // namespace Sqrt + +namespace Sngl { + + static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, + diag::Diagnostics& diagnostics) { + ASRUtils::require_impl(x.n_args == 1, + "ASR Verify: Call `sngl` must have exactly one argument", + x.base.base.loc, diagnostics); + ASR::ttype_t *type = ASRUtils::expr_type(x.m_args[0]); + ASRUtils::require_impl(ASRUtils::is_real(*type), + "ASR Verify: Arguments to `sngl` must be of real type", + x.base.base.loc, diagnostics); + } + + static ASR::expr_t *eval_Sngl(Allocator &al, const Location &loc, + ASR::ttype_t* arg_type, Vec &args) { + double val = ASR::down_cast(expr_value(args[0]))->m_r; + return f(val, arg_type); + } + + static inline ASR::asr_t* create_Sngl( + Allocator& al, const Location& loc, Vec& args, + const std::function err) { + ASR::ttype_t* return_type = real32; + if ( args.n != 1 ) { + err("Intrinsic `sngl` accepts exactly one argument", loc); + } else if ( !is_real(*expr_type(args[0])) ) { + err("Argument of the `sngl` must be Real", loc); + } + Vec m_args; m_args.reserve(al, 1); + m_args.push_back(al, args[0]); + ASR::expr_t *m_value = nullptr; + if (all_args_evaluated(m_args)) { + m_value = eval_Sngl(al, loc, return_type, m_args); + } + return ASR::make_IntrinsicScalarFunction_t(al, loc, + static_cast(IntrinsicScalarFunctions::Sngl), + m_args.p, m_args.n, 0, return_type, m_value); + } + + static inline ASR::expr_t* instantiate_Sngl(Allocator &al, const Location &loc, + SymbolTable *scope, Vec& arg_types, ASR::ttype_t *return_type, + Vec& new_args, int64_t /*overload_id*/) { + std::string func_name = "_lcompilers_sngl_" + type_to_str_python(arg_types[0]); + std::string fn_name = scope->get_unique_name(func_name); + SymbolTable *fn_symtab = al.make_new(scope); + Vec args; + args.reserve(al, new_args.size()); + ASRBuilder b(al, loc); + Vec body; body.reserve(al, 1); + SetChar dep; dep.reserve(al, 1); + if (scope->get_symbol(fn_name)) { + ASR::symbol_t *s = scope->get_symbol(fn_name); + ASR::Function_t *f = ASR::down_cast(s); + return b.Call(s, new_args, expr_type(f->m_return_var), nullptr); + } + fill_func_arg("a", arg_types[0]); + auto result = declare(fn_name, return_type, ReturnVar); + body.push_back(al, b.Assignment(result, r2r32(args[0]))); + + ASR::symbol_t *f_sym = make_ASR_Function_t(fn_name, fn_symtab, dep, args, + body, result, ASR::abiType::Source, ASR::deftypeType::Implementation, nullptr); + scope->add_symbol(fn_name, f_sym); + return b.Call(f_sym, new_args, return_type, nullptr); + } + +} // namespace Sngl + namespace FMA { static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, diag::Diagnostics& diagnostics) { @@ -1988,6 +2145,7 @@ namespace FlipSign { namespace FloorDiv { + static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, diag::Diagnostics& diagnostics) { ASRUtils::require_impl(x.n_args == 2, "ASR Verify: Call to FloorDiv must have exactly 2 arguments", @@ -2004,6 +2162,7 @@ namespace FloorDiv { x.base.base.loc, diagnostics); } + static ASR::expr_t *eval_FloorDiv(Allocator &al, const Location &loc, ASR::ttype_t* t1, Vec &args) { ASR::ttype_t *type1 = ASRUtils::expr_type(args[0]); @@ -2019,6 +2178,7 @@ namespace FloorDiv { bool is_logical1 = is_logical(*type1); bool is_logical2 = is_logical(*type2); + if (is_int1 && is_int2) { int64_t a = ASR::down_cast(args[0])->m_n; int64_t b = ASR::down_cast(args[1])->m_n; @@ -2044,6 +2204,8 @@ namespace FloorDiv { return nullptr; } + + static inline ASR::asr_t* create_FloorDiv(Allocator& al, const Location& loc, Vec& args, const std::function err) { @@ -2102,6 +2264,7 @@ namespace FloorDiv { return result */ + ASR::expr_t *op1 = r64Div(CastingUtil::perform_casting(args[0], arg_types[0], real64, al, loc), CastingUtil::perform_casting(args[1], arg_types[1], real64, al, loc)); body.push_back(al, b.Assignment(r, op1)); @@ -2118,6 +2281,208 @@ namespace FloorDiv { } // namespace FloorDiv +namespace Mod { + + static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, diag::Diagnostics& diagnostics) { + ASRUtils::require_impl(x.n_args == 2, + "ASR Verify: Call to Mod must have exactly 2 arguments", + x.base.base.loc, diagnostics); + ASR::ttype_t *type1 = ASRUtils::expr_type(x.m_args[0]); + ASR::ttype_t *type2 = ASRUtils::expr_type(x.m_args[1]); + ASRUtils::require_impl((is_integer(*type1) && is_integer(*type2)) || + (is_real(*type1) && is_real(*type2)), + "ASR Verify: Arguments to Mod must be of real or integer type", + x.base.base.loc, diagnostics); + } + + static ASR::expr_t *eval_Mod(Allocator &al, const Location &loc, + ASR::ttype_t* t1, Vec &args) { + bool is_real1 = is_real(*ASRUtils::expr_type(args[0])); + bool is_real2 = is_real(*ASRUtils::expr_type(args[1])); + bool is_int1 = is_integer(*ASRUtils::expr_type(args[0])); + bool is_int2 = is_integer(*ASRUtils::expr_type(args[1])); + + if (is_int1 && is_int2) { + int64_t a = ASR::down_cast(args[0])->m_n; + int64_t b = ASR::down_cast(args[1])->m_n; + return make_ConstantWithType(make_IntegerConstant_t, a % b, t1, loc); + } else if (is_real1 && is_real2) { + double a = ASR::down_cast(args[0])->m_r; + double b = ASR::down_cast(args[1])->m_r; + return make_ConstantWithType(make_RealConstant_t, std::fmod(a, b), t1, loc); + } + return nullptr; + } + + static inline ASR::asr_t* create_Mod(Allocator& al, const Location& loc, + Vec& args, + const std::function err) { + if (args.size() != 2) { + err("Intrinsic Mod function accepts exactly 2 arguments", loc); + } + ASR::ttype_t *type1 = ASRUtils::expr_type(args[0]); + ASR::ttype_t *type2 = ASRUtils::expr_type(args[1]); + if (!((ASRUtils::is_integer(*type1) && ASRUtils::is_integer(*type2)) || + (ASRUtils::is_real(*type1) && ASRUtils::is_real(*type2)))) { + err("Argument of the Mod function must be either Real or Integer", + args[0]->base.loc); + } + ASR::expr_t *m_value = nullptr; + if (all_args_evaluated(args)) { + Vec arg_values; arg_values.reserve(al, 2); + arg_values.push_back(al, expr_value(args[0])); + arg_values.push_back(al, expr_value(args[1])); + m_value = eval_Mod(al, loc, expr_type(args[1]), arg_values); + } + return ASR::make_IntrinsicScalarFunction_t(al, loc, + static_cast(IntrinsicScalarFunctions::Mod), + args.p, args.n, 0, ASRUtils::expr_type(args[0]), m_value); + } + + static inline ASR::expr_t* instantiate_Mod(Allocator &al, const Location &loc, + SymbolTable *scope, Vec& arg_types, ASR::ttype_t *return_type, + Vec& new_args, int64_t /*overload_id*/) { + declare_basic_variables("_lcompilers_optimization_mod_" + type_to_str_python(arg_types[1])); + fill_func_arg("a", arg_types[0]); + fill_func_arg("p", arg_types[1]); + auto result = declare(fn_name, return_type, ReturnVar); + /* + function modi32i32(a, p) result(d) + integer(int32), intent(in) :: a, p + integer(int32) :: q + q = a/p + d = a - p*q + end function + */ + + ASR::expr_t *q = nullptr, *op1 = nullptr, *op2 = nullptr; + if (is_real(*arg_types[1])) { + int kind = ASRUtils::extract_kind_from_ttype_t(arg_types[1]); + if (kind == 4) { + q = r2i32(r32Div(args[0], args[1])); + op1 = r32Mul(args[1], i2r32(q)); + op2 = r32Sub(args[0], op1); + } else { + q = r2i64(r64Div(args[0], args[1])); + op1 = r64Mul(args[1], i2r64(q)); + op2 = r64Sub(args[0], op1); + } + } else { + q = iDiv(args[0], args[1]); + op1 = iMul(args[1], q); + op2 = iSub(args[0], op1); + } + body.push_back(al, b.Assignment(result, op2)); + + ASR::symbol_t *f_sym = make_ASR_Function_t(fn_name, fn_symtab, dep, args, + body, result, ASR::abiType::Source, ASR::deftypeType::Implementation, nullptr); + scope->add_symbol(fn_name, f_sym); + return b.Call(f_sym, new_args, return_type, nullptr); + } + +} // namespace Mod + +namespace Trailz { + + static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, diag::Diagnostics& diagnostics) { + ASRUtils::require_impl(x.n_args == 1, + "ASR Verify: Call to Trailz must have exactly 1 argument", + x.base.base.loc, diagnostics); + ASR::ttype_t *type1 = ASRUtils::expr_type(x.m_args[0]); + ASRUtils::require_impl(is_integer(*type1), + "ASR Verify: Arguments to Trailz must be of integer type", + x.base.base.loc, diagnostics); + } + + static ASR::expr_t *eval_Trailz(Allocator &al, const Location &loc, + ASR::ttype_t* t1, Vec &args) { + int64_t a = ASR::down_cast(args[0])->m_n; + int64_t trailing_zeros = ASRUtils::compute_trailing_zeros(a); + return make_ConstantWithType(make_IntegerConstant_t, trailing_zeros, t1, loc); + } + + static inline ASR::asr_t* create_Trailz(Allocator& al, const Location& loc, + Vec& args, + const std::function err) { + if (args.size() != 1) { + err("Intrinsic Trailz function accepts exactly 1 arguments", loc); + } + ASR::ttype_t *type1 = ASRUtils::expr_type(args[0]); + if (!(ASRUtils::is_integer(*type1))) { + err("Argument of the Trailz function must be Integer", + args[0]->base.loc); + } + ASR::expr_t *m_value = nullptr; + if (all_args_evaluated(args)) { + Vec arg_values; arg_values.reserve(al, 1); + arg_values.push_back(al, expr_value(args[0])); + m_value = eval_Trailz(al, loc, expr_type(args[0]), arg_values); + } + return ASR::make_IntrinsicScalarFunction_t(al, loc, + static_cast(IntrinsicScalarFunctions::Trailz), + args.p, args.n, 0, ASRUtils::expr_type(args[0]), m_value); + } + + static inline ASR::expr_t* instantiate_Trailz(Allocator &al, const Location &loc, + SymbolTable *scope, Vec& arg_types, ASR::ttype_t *return_type, + Vec& new_args, int64_t /*overload_id*/) { + declare_basic_variables("_lcompilers_optimization_trailz_" + type_to_str_python(arg_types[0])); + fill_func_arg("n", arg_types[0]); + auto result = declare(fn_name, arg_types[0], ReturnVar); + // This is not the most efficient way to do this, but it works for now. + /* + function trailz(n) result(result) + integer :: n + integer :: result + result = 0 + if (n == 0) then + result = 32 + else + do while (mod(n,2) == 0) + n = n/2 + result = result + 1 + end do + end if + end function + */ + + body.push_back(al, b.Assignment(result, i(0, arg_types[0]))); + ASR::expr_t *two = i(2, arg_types[0]); + int arg_0_kind = ASRUtils::extract_kind_from_ttype_t(arg_types[0]); + + Vec arg_types_mod; arg_types_mod.reserve(al, 2); + arg_types_mod.push_back(al, arg_types[0]); arg_types_mod.push_back(al, ASRUtils::expr_type(two)); + + Vec new_args_mod; new_args_mod.reserve(al, 2); + ASR::call_arg_t arg1; arg1.loc = loc; arg1.m_value = args[0]; + ASR::call_arg_t arg2; arg2.loc = loc; arg2.m_value = two; + new_args_mod.push_back(al, arg1); new_args_mod.push_back(al, arg2); + + ASR::expr_t* func_call_mod = Mod::instantiate_Mod(al, loc, scope, arg_types_mod, return_type, new_args_mod, 0); + ASR::expr_t *cond = iEq(func_call_mod, i(0, arg_types[0])); + + std::vector while_loop_body; + if (arg_0_kind == 4) { + while_loop_body.push_back(b.Assignment(args[0], iDiv(args[0], two))); + while_loop_body.push_back(b.Assignment(result, iAdd(result, i(1, arg_types[0])))); + } else { + while_loop_body.push_back(b.Assignment(args[0], iDiv64(args[0], two))); + while_loop_body.push_back(b.Assignment(result, iAdd64(result, i(1, arg_types[0])))); + } + + ASR::expr_t* check_zero = iEq(args[0], i(0, arg_types[0])); + std::vector if_body; if_body.push_back(b.Assignment(result, i(32, arg_types[0]))); + std::vector else_body; else_body.push_back(b.While(cond, while_loop_body)); + body.push_back(al, b.If(check_zero, if_body, else_body)); + + ASR::symbol_t *f_sym = make_ASR_Function_t(fn_name, fn_symtab, dep, args, + body, result, ASR::abiType::Source, ASR::deftypeType::Implementation, nullptr); + scope->add_symbol(fn_name, f_sym); + return b.Call(f_sym, new_args, return_type, nullptr); + } + +} // namespace Trailz + #define create_exp_macro(X, stdeval) \ namespace X { \ static inline ASR::expr_t* eval_##X(Allocator &al, const Location &loc, \ @@ -3209,7 +3574,6 @@ create_symbolic_query_macro(SymbolicPowQ) create_symbolic_query_macro(SymbolicLogQ) create_symbolic_query_macro(SymbolicSinQ) - #define create_symbolic_unary_macro(X) \ namespace X { \ static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, \ @@ -3299,6 +3663,10 @@ namespace IntrinsicScalarFunctionRegistry { {&FlipSign::instantiate_FlipSign, &FlipSign::verify_args}}, {static_cast(IntrinsicScalarFunctions::FloorDiv), {&FloorDiv::instantiate_FloorDiv, &FloorDiv::verify_args}}, + {static_cast(IntrinsicScalarFunctions::Mod), + {&Mod::instantiate_Mod, &Mod::verify_args}}, + {static_cast(IntrinsicScalarFunctions::Trailz), + {&Trailz::instantiate_Trailz, &Trailz::verify_args}}, {static_cast(IntrinsicScalarFunctions::Abs), {&Abs::instantiate_Abs, &Abs::verify_args}}, {static_cast(IntrinsicScalarFunctions::Partition), @@ -3329,6 +3697,10 @@ namespace IntrinsicScalarFunctionRegistry { {nullptr, &Radix::verify_args}}, {static_cast(IntrinsicScalarFunctions::Aint), {&Aint::instantiate_Aint, &Aint::verify_args}}, + {static_cast(IntrinsicScalarFunctions::Sqrt), + {&Sqrt::instantiate_Sqrt, &Sqrt::verify_args}}, + {static_cast(IntrinsicScalarFunctions::Sngl), + {&Sngl::instantiate_Sngl, &Sngl::verify_args}}, {static_cast(IntrinsicScalarFunctions::SignFromValue), {&SignFromValue::instantiate_SignFromValue, &SignFromValue::verify_args}}, {static_cast(IntrinsicScalarFunctions::SymbolicSymbol), @@ -3419,6 +3791,10 @@ namespace IntrinsicScalarFunctionRegistry { "flipsign"}, {static_cast(IntrinsicScalarFunctions::FloorDiv), "floordiv"}, + {static_cast(IntrinsicScalarFunctions::Mod), + "mod"}, + {static_cast(IntrinsicScalarFunctions::Trailz), + "trailz"}, {static_cast(IntrinsicScalarFunctions::Expm1), "expm1"}, {static_cast(IntrinsicScalarFunctions::ListIndex), @@ -3447,6 +3823,10 @@ namespace IntrinsicScalarFunctionRegistry { "sign"}, {static_cast(IntrinsicScalarFunctions::Aint), "aint"}, + {static_cast(IntrinsicScalarFunctions::Sqrt), + "sqrt"}, + {static_cast(IntrinsicScalarFunctions::Sngl), + "sngl"}, {static_cast(IntrinsicScalarFunctions::SignFromValue), "signfromvalue"}, {static_cast(IntrinsicScalarFunctions::SymbolicSymbol), @@ -3520,6 +3900,8 @@ namespace IntrinsicScalarFunctionRegistry { {"expm1", {&Expm1::create_Expm1, &Expm1::eval_Expm1}}, {"fma", {&FMA::create_FMA, &FMA::eval_FMA}}, {"floordiv", {&FloorDiv::create_FloorDiv, &FloorDiv::eval_FloorDiv}}, + {"mod", {&Mod::create_Mod, &Mod::eval_Mod}}, + {"trailz", {&Trailz::create_Trailz, &Trailz::eval_Trailz}}, {"list.index", {&ListIndex::create_ListIndex, &ListIndex::eval_list_index}}, {"list.reverse", {&ListReverse::create_ListReverse, &ListReverse::eval_list_reverse}}, {"list.pop", {&ListPop::create_ListPop, &ListPop::eval_list_pop}}, @@ -3534,6 +3916,8 @@ namespace IntrinsicScalarFunctionRegistry { {"radix", {&Radix::create_Radix, nullptr}}, {"sign", {&Sign::create_Sign, &Sign::eval_Sign}}, {"aint", {&Aint::create_Aint, &Aint::eval_Aint}}, + {"sqrt", {&Sqrt::create_Sqrt, &Sqrt::eval_Sqrt}}, + {"sngl", {&Sngl::create_Sngl, &Sngl::eval_Sngl}}, {"Symbol", {&SymbolicSymbol::create_SymbolicSymbol, &SymbolicSymbol::eval_SymbolicSymbol}}, {"SymbolicAdd", {&SymbolicAdd::create_SymbolicAdd, &SymbolicAdd::eval_SymbolicAdd}}, {"SymbolicSub", {&SymbolicSub::create_SymbolicSub, &SymbolicSub::eval_SymbolicSub}}, @@ -3579,6 +3963,9 @@ namespace IntrinsicScalarFunctionRegistry { id_ == IntrinsicScalarFunctions::Exp || id_ == IntrinsicScalarFunctions::Exp2 || id_ == IntrinsicScalarFunctions::Expm1 || + id_ == IntrinsicScalarFunctions::Min || + id_ == IntrinsicScalarFunctions::Max || + id_ == IntrinsicScalarFunctions::Sqrt || id_ == IntrinsicScalarFunctions::SymbolicSymbol); } diff --git a/src/libasr/pass/nested_vars.cpp b/src/libasr/pass/nested_vars.cpp index cf3bf090f6b..e2ec051f600 100644 --- a/src/libasr/pass/nested_vars.cpp +++ b/src/libasr/pass/nested_vars.cpp @@ -171,6 +171,10 @@ class NestedVarVisitor : public ASR::BaseWalkVisitor } } } + + void visit_ArrayBroadcast(const ASR::ArrayBroadcast_t& x) { + visit_expr(*x.m_array); + } }; @@ -219,6 +223,13 @@ class ReplacerNestedVars: public ASR::BaseExprReplacer { void replace_Array(ASR::Array_t* /*x*/) { return ; } + + void replace_ArrayBroadcast(ASR::ArrayBroadcast_t* x) { + ASR::expr_t** current_expr_copy_161 = current_expr; + current_expr = &(x->m_array); + replace_expr(x->m_array); + current_expr = current_expr_copy_161; + } }; class ReplaceNestedVisitor: public ASR::CallReplacerOnExpressionsVisitor { @@ -437,6 +448,16 @@ class ReplaceNestedVisitor: public ASR::CallReplacerOnExpressionsVisitor(&(x.m_array)); + call_replacer(); + current_expr = current_expr_copy_269; + if( x.m_array ) { + visit_expr(*x.m_array); + } + } + }; class AssignNestedVars: public PassUtils::PassVisitor { @@ -614,6 +635,10 @@ class AssignNestedVars: public PassUtils::PassVisitor { void visit_Array(const ASR::Array_t& /*x*/) { return ; } + + void visit_ArrayBroadcast(const ASR::ArrayBroadcast_t& x) { + visit_expr(*x.m_array); + } }; void pass_nested_vars(Allocator &al, ASR::TranslationUnit_t &unit, diff --git a/src/libasr/pass/pass_array_by_data.cpp b/src/libasr/pass/pass_array_by_data.cpp index a9dbcdcaa11..20f925f17c7 100644 --- a/src/libasr/pass/pass_array_by_data.cpp +++ b/src/libasr/pass/pass_array_by_data.cpp @@ -93,10 +93,19 @@ class PassArrayByDataProcedureVisitor : public PassUtils::PassVisitor(current_scope); ASRUtils::SymbolDuplicator symbol_duplicator(al); + // first duplicate the external symbols + // so they can be referenced by derived_type for( auto& item: x->m_symtab->get_scope() ) { - symbol_duplicator.duplicate_symbol(item.second, new_symtab); + if (ASR::is_a(*item.second)) { + symbol_duplicator.duplicate_symbol(item.second, new_symtab); + } } - Vec new_args; + for( auto& item: x->m_symtab->get_scope() ) { + if (!ASR::is_a(*item.second)) { + symbol_duplicator.duplicate_symbol(item.second, new_symtab); + } + } + Vec new_args; std::string suffix = ""; new_args.reserve(al, x->n_args); ASR::expr_t* return_var = nullptr; @@ -204,21 +213,6 @@ class PassArrayByDataProcedureVisitor : public PassUtils::PassVisitorn_args = new_args.size(); } - void visit_TranslationUnit(const ASR::TranslationUnit_t& x) { - // Visit Module first so that all functions in it are updated - for (auto &a : x.m_symtab->get_scope()) { - if( ASR::is_a(*a.second) ) { - this->visit_symbol(*a.second); - } - } - - // Visit all other symbols - for (auto &a : x.m_symtab->get_scope()) { - if( !ASR::is_a(*a.second) ) { - this->visit_symbol(*a.second); - } - } - } template bool visit_SymbolContainingFunctions(const T& x, @@ -250,6 +244,25 @@ class PassArrayByDataProcedureVisitor : public PassUtils::PassVisitorget_scope()) { + if( ASR::is_a(*a.second) ) { + this->visit_symbol(*a.second); + } + } + + // Visit the program + for (auto &a : x.m_symtab->get_scope()) { + if( ASR::is_a(*a.second) ) { + this->visit_symbol(*a.second); + } + } + } + void visit_Program(const ASR::Program_t& x) { bfs_visit_SymbolContainingFunctions() } @@ -294,8 +307,21 @@ class EditProcedureReplacer: public ASR::BaseExprReplacer void replace_Var(ASR::Var_t* x) { ASR::symbol_t* x_sym_ = x->m_v; - if ( v.proc2newproc.find(x_sym_) != v.proc2newproc.end() ) { - x->m_v = v.proc2newproc[x_sym_].first; + bool is_external = ASR::is_a(*x_sym_); + if ( v.proc2newproc.find(ASRUtils::symbol_get_past_external(x_sym_)) != v.proc2newproc.end() ) { + x->m_v = v.proc2newproc[ASRUtils::symbol_get_past_external(x_sym_)].first; + if( is_external ) { + ASR::ExternalSymbol_t* x_sym_ext = ASR::down_cast(x_sym_); + std::string new_func_sym_name = current_scope->get_unique_name(ASRUtils::symbol_name( + ASRUtils::symbol_get_past_external(x_sym_))); + ASR::symbol_t* new_func_sym_ = ASR::down_cast( + ASR::make_ExternalSymbol_t(v.al, x_sym_->base.loc, current_scope, + s2c(v.al, new_func_sym_name), x->m_v, x_sym_ext->m_module_name, + x_sym_ext->m_scope_names, x_sym_ext->n_scope_names, ASRUtils::symbol_name(x->m_v), + x_sym_ext->m_access)); + current_scope->add_symbol(new_func_sym_name, new_func_sym_); + x->m_v = new_func_sym_; + } return ; } @@ -580,8 +606,9 @@ class RemoveArrayByDescriptorProceduresVisitor : public PassUtils::PassVisitor(x); + template + void visit_Unit(const T& x) { + T& xx = const_cast(x); current_scope = xx.m_symtab; std::vector to_be_erased; @@ -599,23 +626,16 @@ class RemoveArrayByDescriptorProceduresVisitor : public PassUtils::PassVisitor(x); - current_scope = xx.m_symtab; - - std::vector to_be_erased; + void visit_TranslationUnit(const ASR::TranslationUnit_t& x) { + visit_Unit(x); + } - for( auto& item: current_scope->get_scope() ) { - if( v.proc2newproc.find(item.second) != v.proc2newproc.end() && - not_to_be_erased.find(item.second) == not_to_be_erased.end() ) { - LCOMPILERS_ASSERT(item.first == ASRUtils::symbol_name(item.second)) - to_be_erased.push_back(item.first); - } - } + void visit_Program(const ASR::Program_t& x) { + visit_Unit(x); + } - for (auto &item: to_be_erased) { - current_scope->erase_symbol(item); - } + void visit_Function(const ASR::Function_t& x) { + visit_Unit(x); } }; diff --git a/src/libasr/pass/pass_manager.h b/src/libasr/pass/pass_manager.h index 0983bb8587a..f419726036d 100644 --- a/src/libasr/pass/pass_manager.h +++ b/src/libasr/pass/pass_manager.h @@ -48,7 +48,9 @@ #include #include #include +#include #include +#include #include #include @@ -100,7 +102,8 @@ namespace LCompilers { {"nested_vars", &pass_nested_vars}, {"where", &pass_replace_where}, {"print_struct_type", &pass_replace_print_struct_type}, - {"unique_symbols", &pass_unique_symbols} + {"unique_symbols", &pass_unique_symbols}, + {"insert_deallocate", &pass_insert_deallocate} }; bool is_fast; @@ -108,6 +111,7 @@ namespace LCompilers { bool c_skip_pass; // This will contain the passes that are to be skipped in C public: + bool rtlib=false; void apply_passes(Allocator& al, ASR::TranslationUnit_t* asr, std::vector& passes, PassOptions &pass_options, @@ -154,14 +158,6 @@ namespace LCompilers { std::cerr << "ASR Pass starts: '" << passes[i] << "'\n"; } _passes_db[passes[i]](al, *asr, pass_options); - if (pass_options.dump_all_passes) { - std::string str_i = std::to_string(i+1); - if ( i < 9 ) str_i = "0" + str_i; - std::ofstream outfile ("pass_" + str_i + "_" + passes[i] + ".clj"); - outfile << ";; ASR after applying the pass: " << passes[i] - << "\n" << pickle(*asr, false, true) << "\n"; - outfile.close(); - } #if defined(WITH_LFORTRAN_ASSERT) if (!asr_verify(*asr, true, diagnostics)) { std::cerr << diagnostics.render2(); @@ -175,8 +171,6 @@ namespace LCompilers { } } - bool rtlib=false; - void _parse_pass_arg(std::string& arg, std::vector& passes) { if (arg == "") return; @@ -211,16 +205,13 @@ namespace LCompilers { "implied_do_loops", "class_constructor", "pass_list_expr", - // "arr_slice", TODO: Remove ``arr_slice.cpp`` completely "where", "subroutine_from_function", "array_op", - // "subroutine_from_function", "symbolic", "intrinsic_function", "subroutine_from_function", "array_op", - // "subroutine_from_function", "pass_array_by_data", "print_struct_type", "print_arr", @@ -233,29 +224,31 @@ namespace LCompilers { "inline_function_calls", "unused_functions", "transform_optional_argument_functions", - "unique_symbols" + "unique_symbols", + "insert_deallocate" }; _with_optimization_passes = { + "nested_vars", "global_stmts", "init_expr", "implied_do_loops", "class_constructor", - "pass_array_by_data", - // "arr_slice", TODO: Remove ``arr_slice.cpp`` completely + "pass_list_expr", + "where", "subroutine_from_function", "array_op", + "symbolic", "intrinsic_function", "subroutine_from_function", "array_op", + "pass_array_by_data", "print_struct_type", "print_arr", "print_list_tuple", "print_struct_type", "loop_vectorise", - "loop_unroll", "array_dim_intrinsics_update", - "where", "do_loops", "forall", "dead_code_removal", @@ -266,8 +259,9 @@ namespace LCompilers { "div_to_mul", "fma", "transform_optional_argument_functions", - // "inline_function_calls", TODO: Uncomment later - "unique_symbols" + "inline_function_calls", + "unique_symbols", + "insert_deallocate" }; // These are re-write passes which are already handled @@ -307,6 +301,76 @@ namespace LCompilers { } } + void dump_all_passes(Allocator& al, ASR::TranslationUnit_t* asr, + PassOptions &pass_options, + [[maybe_unused]] diag::Diagnostics &diagnostics, LocationManager &lm) { + std::vector passes; + if (pass_options.fast) { + passes = _with_optimization_passes; + } else { + passes = _passes; + } + for (size_t i = 0; i < passes.size(); i++) { + // TODO: rework the whole pass manager: construct the passes + // ahead of time (not at the last minute), and remove this much + // earlier + // Note: this is not enough for rtlib, we also need to include + // it + if (pass_options.verbose) { + std::cerr << "ASR Pass starts: '" << passes[i] << "'\n"; + } + _passes_db[passes[i]](al, *asr, pass_options); + if (pass_options.dump_all_passes) { + std::string str_i = std::to_string(i+1); + if ( i < 9 ) str_i = "0" + str_i; + if (pass_options.json) { + std::ofstream outfile ("pass_json_" + str_i + "_" + passes[i] + ".json"); + outfile << pickle_json(*asr, lm, pass_options.no_loc, pass_options.with_intrinsic_mods) << "\n"; + outfile.close(); + } + if (pass_options.tree) { + std::ofstream outfile ("pass_tree_" + str_i + "_" + passes[i] + ".txt"); + outfile << pickle_tree(*asr, false, pass_options.with_intrinsic_mods) << "\n"; + outfile.close(); + } + if (pass_options.visualize) { + std::string json = pickle_json(*asr, lm, pass_options.no_loc, pass_options.with_intrinsic_mods); + std::ofstream outfile ("pass_viz_" + str_i + "_" + passes[i] + ".html"); + outfile << generate_visualize_html(json) << "\n"; + outfile.close(); + } + std::ofstream outfile ("pass_" + str_i + "_" + passes[i] + ".clj"); + outfile << ";; ASR after applying the pass: " << passes[i] + << "\n" << pickle(*asr, false, true, pass_options.with_intrinsic_mods) << "\n"; + outfile.close(); + } + if (pass_options.dump_fortran) { + LCompilers::Result fortran_code = asr_to_fortran(*asr, diagnostics, false, 4); + if (!fortran_code.ok) { + LCOMPILERS_ASSERT(diagnostics.has_error()); + throw LCompilersException("Fortran code could not be generated after pass: " + + passes[i]); + } + std::string str_i = std::to_string(i+1); + if ( i < 9 ) str_i = "0" + str_i; + std::ofstream outfile ("pass_fortran_" + str_i + "_" + passes[i] + ".f90"); + outfile << "! Fortran code after applying the pass: " << passes[i] + << "\n" << fortran_code.result << "\n"; + outfile.close(); + } +#if defined(WITH_LFORTRAN_ASSERT) + if (!asr_verify(*asr, true, diagnostics)) { + std::cerr << diagnostics.render2(); + throw LCompilersException("Verify failed in the pass: " + + passes[i]); + }; +#endif + if (pass_options.verbose) { + std::cerr << "ASR Pass ends: '" << passes[i] << "'\n"; + } + } + } + void use_optimization_passes() { is_fast = true; } diff --git a/src/libasr/pass/pass_utils.cpp b/src/libasr/pass/pass_utils.cpp index 18fa18c1e4c..4a581ac1445 100644 --- a/src/libasr/pass/pass_utils.cpp +++ b/src/libasr/pass/pass_utils.cpp @@ -112,7 +112,8 @@ namespace LCompilers { arr_expr->base.loc, arr_expr, args.p, args.size(), ASRUtils::type_get_past_array( - ASRUtils::type_get_past_allocatable(array_ref_type)), + ASRUtils::type_get_past_pointer( + ASRUtils::type_get_past_allocatable(array_ref_type))), ASR::arraystorageType::RowMajor, nullptr)); if( perform_cast ) { LCOMPILERS_ASSERT(casted_type != nullptr); @@ -143,7 +144,8 @@ namespace LCompilers { arr_expr->base.loc, arr_expr, args.p, args.size(), ASRUtils::type_get_past_array( - ASRUtils::type_get_past_allocatable(array_ref_type)), + ASRUtils::type_get_past_pointer( + ASRUtils::type_get_past_allocatable(array_ref_type))), ASR::arraystorageType::RowMajor, nullptr)); if( perform_cast ) { LCOMPILERS_ASSERT(casted_type != nullptr); @@ -219,9 +221,6 @@ namespace LCompilers { ASR::ttype_t* get_matching_type(ASR::expr_t* sibling, Allocator& al) { ASR::ttype_t* sibling_type = ASRUtils::expr_type(sibling); - if( sibling->type != ASR::exprType::Var ) { - return sibling_type; - } ASR::dimension_t* m_dims; int ndims; PassUtils::get_dim_rank(sibling_type, m_dims, ndims); @@ -252,6 +251,16 @@ namespace LCompilers { ASR::expr_t* create_var(int counter, std::string suffix, const Location& loc, ASR::ttype_t* var_type, Allocator& al, SymbolTable*& current_scope) { + ASR::dimension_t* m_dims = nullptr; + int ndims = 0; + PassUtils::get_dim_rank(var_type, m_dims, ndims); + if( !ASRUtils::is_fixed_size_array(m_dims, ndims) && + !ASRUtils::is_dimension_dependent_only_on_arguments(m_dims, ndims) && + !(ASR::is_a(*var_type) || ASR::is_a(*var_type)) ) { + var_type = ASRUtils::TYPE(ASR::make_Allocatable_t(al, var_type->base.loc, + ASRUtils::type_get_past_allocatable( + ASRUtils::duplicate_type_with_empty_dims(al, var_type)))); + } ASR::expr_t* idx_var = nullptr; std::string str_name = "__libasr__created__var__" + std::to_string(counter) + "_" + suffix; char* idx_var_name = s2c(al, str_name); @@ -546,6 +555,7 @@ namespace LCompilers { ASR::expr_t* create_binop_helper(Allocator &al, const Location &loc, ASR::expr_t* left, ASR::expr_t* right, ASR::binopType op) { ASR::ttype_t* type = ASRUtils::expr_type(left); + ASRUtils::make_ArrayBroadcast_t_util(al, loc, left, right); // TODO: compute `value`: if (ASRUtils::is_integer(*type)) { return ASRUtils::EXPR(ASR::make_IntegerBinOp_t(al, loc, left, op, right, type, nullptr)); @@ -860,8 +870,8 @@ namespace LCompilers { ASR::expr_t *c=loop.m_head.m_increment; ASR::expr_t *cond = nullptr; ASR::stmt_t *inc_stmt = nullptr; - ASR::stmt_t *stmt1 = nullptr; - ASR::stmt_t *stmt_add_c = nullptr; + ASR::stmt_t *loop_init_stmt = nullptr; + ASR::stmt_t *stmt_add_c_after_loop = nullptr; if( !a && !b && !c ) { int a_kind = 4; if( loop.m_head.m_v ) { @@ -952,12 +962,12 @@ namespace LCompilers { int a_kind = ASRUtils::extract_kind_from_ttype_t(ASRUtils::expr_type(target)); ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_Integer_t(al, loc, a_kind)); - stmt1 = ASRUtils::STMT(ASR::make_Assignment_t(al, loc, target, + loop_init_stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, loc, target, ASRUtils::EXPR(ASR::make_IntegerBinOp_t(al, loc, a, ASR::binopType::Sub, c, type, nullptr)), nullptr)); if (use_loop_variable_after_loop) { - stmt_add_c = ASRUtils::STMT(ASR::make_Assignment_t(al, loc, target, - ASRUtils::EXPR(ASR::make_IntegerBinOp_t(al, loc, a, + stmt_add_c_after_loop = ASRUtils::STMT(ASR::make_Assignment_t(al, loc, target, + ASRUtils::EXPR(ASR::make_IntegerBinOp_t(al, loc, target, ASR::binopType::Add, c, type, nullptr)), nullptr)); } @@ -979,16 +989,16 @@ namespace LCompilers { for (size_t i=0; i result; result.reserve(al, 2); - if( stmt1 ) { - result.push_back(al, stmt1); + if( loop_init_stmt ) { + result.push_back(al, loop_init_stmt); } - result.push_back(al, stmt2); - if (stmt_add_c && use_loop_variable_after_loop) { - result.push_back(al, stmt_add_c); + result.push_back(al, while_loop_stmt); + if (stmt_add_c_after_loop && use_loop_variable_after_loop) { + result.push_back(al, stmt_add_c_after_loop); } return result; diff --git a/src/libasr/pass/pass_utils.h b/src/libasr/pass/pass_utils.h index dddd00b3b7c..828d6033a9a 100644 --- a/src/libasr/pass/pass_utils.h +++ b/src/libasr/pass/pass_utils.h @@ -299,6 +299,8 @@ namespace LCompilers { } void visit_Module(const ASR::Module_t& x) { + SymbolTable *parent_symtab = current_scope; + current_scope = x.m_symtab; ASR::Module_t& xx = const_cast(x); module_dependencies.n = 0; module_dependencies.reserve(al, 1); @@ -311,6 +313,7 @@ namespace LCompilers { xx.n_dependencies = module_dependencies.size(); xx.m_dependencies = module_dependencies.p; fill_module_dependencies = fill_module_dependencies_copy; + current_scope = parent_symtab; } void visit_Variable(const ASR::Variable_t& x) { @@ -341,12 +344,17 @@ namespace LCompilers { SymbolTable* temp_scope = current_scope; if (asr_owner_sym && temp_scope->get_counter() != ASRUtils::symbol_parent_symtab(x.m_name)->get_counter() && - !ASR::is_a(*asr_owner_sym) && !ASR::is_a(*x.m_name) && - !ASR::is_a(*x.m_name)) { - function_dependencies.push_back(al, ASRUtils::symbol_name(x.m_name)); + !ASR::is_a(*x.m_name) && !ASR::is_a(*x.m_name)) { + if (ASR::is_a(*asr_owner_sym) || ASR::is_a(*asr_owner_sym)) { + temp_scope = temp_scope->parent; + if (temp_scope->get_counter() != ASRUtils::symbol_parent_symtab(x.m_name)->get_counter()) { + function_dependencies.push_back(al, ASRUtils::symbol_name(x.m_name)); + } + } else { + function_dependencies.push_back(al, ASRUtils::symbol_name(x.m_name)); + } } } - if( ASR::is_a(*x.m_name) && fill_module_dependencies ) { ASR::ExternalSymbol_t* x_m_name = ASR::down_cast(x.m_name); @@ -365,11 +373,17 @@ namespace LCompilers { } SymbolTable* temp_scope = current_scope; - + if (asr_owner_sym && temp_scope->get_counter() != ASRUtils::symbol_parent_symtab(x.m_name)->get_counter() && - !ASR::is_a(*asr_owner_sym) && !ASR::is_a(*x.m_name) && - !ASR::is_a(*x.m_name)) { - function_dependencies.push_back(al, ASRUtils::symbol_name(x.m_name)); + !ASR::is_a(*x.m_name) && !ASR::is_a(*x.m_name)) { + if (ASR::is_a(*asr_owner_sym) || ASR::is_a(*asr_owner_sym)) { + temp_scope = temp_scope->parent; + if (temp_scope->get_counter() != ASRUtils::symbol_parent_symtab(x.m_name)->get_counter()) { + function_dependencies.push_back(al, ASRUtils::symbol_name(x.m_name)); + } + } else { + function_dependencies.push_back(al, ASRUtils::symbol_name(x.m_name)); + } } } @@ -384,10 +398,13 @@ namespace LCompilers { } void visit_BlockCall(const ASR::BlockCall_t& x) { + SymbolTable *parent_symtab = current_scope; ASR::Block_t* block = ASR::down_cast(x.m_m); + current_scope = block->m_symtab; for (size_t i=0; in_body; i++) { visit_stmt(*(block->m_body[i])); } + current_scope = parent_symtab; } void visit_AssociateBlock(const ASR::AssociateBlock_t& x) { @@ -412,7 +429,7 @@ namespace LCompilers { visit_stmt(*x.m_body[i]); } current_scope = parent_symtab; - } + } }; namespace ReplacerUtils { @@ -552,7 +569,8 @@ namespace LCompilers { array_ref_type = ASRUtils::duplicate_type(al, array_ref_type, &empty_dims); ASR::expr_t* array_ref = ASRUtils::EXPR(ASRUtils::make_ArrayItem_t_util(al, arr_var->base.loc, arr_var, args.p, args.size(), - ASRUtils::type_get_past_allocatable(array_ref_type), + ASRUtils::type_get_past_pointer( + ASRUtils::type_get_past_allocatable(array_ref_type)), ASR::arraystorageType::RowMajor, nullptr)); if( ASR::is_a(*idoloop->m_values[i]) ) { create_do_loop(al, ASR::down_cast(idoloop->m_values[i]), @@ -734,7 +752,8 @@ namespace LCompilers { target_section->base.base.loc, target_section->m_v, args.p, args.size(), - ASRUtils::type_get_past_allocatable(array_ref_type), + ASRUtils::type_get_past_pointer( + ASRUtils::type_get_past_allocatable(array_ref_type)), ASR::arraystorageType::RowMajor, nullptr)); ASR::expr_t* x_m_args_k = x->m_args[k]; if( perform_cast ) { diff --git a/src/libasr/pass/replace_symbolic.h b/src/libasr/pass/replace_symbolic.h index 7e32aefffcd..a191da0ba4e 100644 --- a/src/libasr/pass/replace_symbolic.h +++ b/src/libasr/pass/replace_symbolic.h @@ -11,4 +11,4 @@ namespace LCompilers { } // namespace LCompilers -#endif // LIBASR_PASS_REPLACE_SYMBOLIC_H \ No newline at end of file +#endif // LIBASR_PASS_REPLACE_SYMBOLIC_H diff --git a/src/libasr/pass/subroutine_from_function.cpp b/src/libasr/pass/subroutine_from_function.cpp index eff6493a064..d7b4c2d8ebb 100644 --- a/src/libasr/pass/subroutine_from_function.cpp +++ b/src/libasr/pass/subroutine_from_function.cpp @@ -122,8 +122,7 @@ class ReplaceFunctionCallWithSubroutineCall: bool& apply_again_): al(al_), result_counter(0), pass_result(pass_result_), resultvar2value(resultvar2value_), result_var(nullptr), - apply_again(apply_again_) - {} + apply_again(apply_again_) {} void replace_FunctionCall(ASR::FunctionCall_t* x) { // The following checks if the name of a function actually @@ -180,6 +179,9 @@ class ReplaceFunctionCallWithSubroutineCall: bool is_allocatable = false; bool is_func_call_allocatable = false; bool is_result_var_allocatable = false; + bool is_created_result_var_type_dependent_on_local_vars = false; + ASR::dimension_t* m_dims_ = nullptr; + size_t n_dims_ = 0; ASR::Function_t *fn = ASR::down_cast(fn_name); { // Assuming the `m_return_var` is appended to the `args`. @@ -192,10 +194,13 @@ class ReplaceFunctionCallWithSubroutineCall: is_result_var_allocatable = ASR::is_a(*ASRUtils::expr_type(result_var_)); is_allocatable = is_func_call_allocatable || is_result_var_allocatable; } - if( is_allocatable ) { + n_dims_ = ASRUtils::extract_dimensions_from_ttype(result_var_type, m_dims_); + is_created_result_var_type_dependent_on_local_vars = !ASRUtils::is_dimension_dependent_only_on_arguments(m_dims_, n_dims_); + if( is_allocatable || is_created_result_var_type_dependent_on_local_vars ) { result_var_type = ASRUtils::duplicate_type_with_empty_dims(al, result_var_type); result_var_type = ASRUtils::TYPE(ASR::make_Allocatable_t( - al, loc, ASRUtils::type_get_past_allocatable(result_var_type))); + al, loc, ASRUtils::type_get_past_allocatable( + ASRUtils::type_get_past_pointer(result_var_type)))); } } @@ -238,8 +243,48 @@ class ReplaceFunctionCallWithSubroutineCall: alloc_arg.m_dims = vec_dims.p; alloc_arg.n_dims = vec_dims.n; vec_alloc.push_back(al, alloc_arg); + Vec to_be_deallocated; + to_be_deallocated.reserve(al, vec_alloc.size()); + for( size_t i = 0; i < vec_alloc.size(); i++ ) { + to_be_deallocated.push_back(al, vec_alloc.p[i].m_a); + } + pass_result.push_back(al, ASRUtils::STMT(ASR::make_ExplicitDeallocate_t( + al, loc, to_be_deallocated.p, to_be_deallocated.size()))); pass_result.push_back(al, ASRUtils::STMT(ASR::make_Allocate_t( al, loc, vec_alloc.p, 1, nullptr, nullptr, nullptr))); + } else if( !is_func_call_allocatable && is_created_result_var_type_dependent_on_local_vars ) { + Vec alloc_dims; + alloc_dims.reserve(al, n_dims_); + for( size_t i = 0; i < n_dims_; i++ ) { + ASR::dimension_t alloc_dim; + alloc_dim.loc = loc; + alloc_dim.m_start = make_ConstantWithKind(make_IntegerConstant_t, make_Integer_t, 1, 4, loc); + if( m_dims_[i].m_length ) { + alloc_dim.m_length = m_dims_[i].m_length; + } else { + alloc_dim.m_length = ASRUtils::get_size(result_var, i + 1, al); + } + alloc_dims.push_back(al, alloc_dim); + } + Vec alloc_args; + alloc_args.reserve(al, 1); + ASR::alloc_arg_t alloc_arg; + alloc_arg.loc = loc; + alloc_arg.m_len_expr = nullptr; + alloc_arg.m_type = nullptr; + alloc_arg.m_a = *current_expr; + alloc_arg.m_dims = alloc_dims.p; + alloc_arg.n_dims = alloc_dims.size(); + alloc_args.push_back(al, alloc_arg); + Vec to_be_deallocated; + to_be_deallocated.reserve(al, alloc_args.size()); + for( size_t i = 0; i < alloc_args.size(); i++ ) { + to_be_deallocated.push_back(al, alloc_args.p[i].m_a); + } + pass_result.push_back(al, ASRUtils::STMT(ASR::make_ExplicitDeallocate_t( + al, loc, to_be_deallocated.p, to_be_deallocated.size()))); + pass_result.push_back(al, ASRUtils::STMT(ASR::make_Allocate_t(al, + loc, alloc_args.p, alloc_args.size(), nullptr, nullptr, nullptr))); } Vec s_args; @@ -329,7 +374,8 @@ class ReplaceFunctionCallWithSubroutineCallVisitor: return ; } - if( PassUtils::is_array(x.m_target) ) { + if( PassUtils::is_array(x.m_target) + || ASR::is_a(*x.m_target)) { replacer.result_var = x.m_target; ASR::expr_t* original_value = x.m_value; resultvar2value[replacer.result_var] = original_value; diff --git a/src/libasr/pass/transform_optional_argument_functions.cpp b/src/libasr/pass/transform_optional_argument_functions.cpp index 18330caaa75..7e379fb6b85 100644 --- a/src/libasr/pass/transform_optional_argument_functions.cpp +++ b/src/libasr/pass/transform_optional_argument_functions.cpp @@ -333,7 +333,7 @@ bool fill_new_args(Vec& new_args, Allocator& al, size_t k; bool k_found = false; for( k = 0; k < owning_function->n_args; k++ ) { - if( ASR::down_cast(owning_function->m_args[k])->m_v == + if( ASR::is_a(*x.m_args[i].m_value) && ASR::down_cast(owning_function->m_args[k])->m_v == ASR::down_cast(x.m_args[i].m_value)->m_v ) { k_found = true; break ; diff --git a/src/libasr/pass/where.cpp b/src/libasr/pass/where.cpp index d2dd41bc3a9..65ff983c27f 100644 --- a/src/libasr/pass/where.cpp +++ b/src/libasr/pass/where.cpp @@ -112,6 +112,9 @@ class ReplaceVar : public ASR::BaseExprReplacer *current_expr = new_expr; } + void replace_Array(ASR::Array_t */*x*/) { + // pass + } }; class VarVisitor : public ASR::CallReplacerOnExpressionsVisitor @@ -172,6 +175,11 @@ class VarVisitor : public ASR::CallReplacerOnExpressionsVisitor ASR::expr_t* value = *replacer.current_expr; current_expr = current_expr_copy; this->visit_expr(*x.m_value); + if( !ASRUtils::is_array(ASRUtils::expr_type(target)) ) { + if( ASR::is_a(*value) ) { + value = ASR::down_cast(value)->m_array; + } + } ASR::stmt_t* tmp_stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, target, value, nullptr)); pass_result.push_back(al, tmp_stmt); } diff --git a/src/libasr/pickle.cpp b/src/libasr/pickle.cpp index 482a0806842..5c54d6909b9 100644 --- a/src/libasr/pickle.cpp +++ b/src/libasr/pickle.cpp @@ -214,7 +214,6 @@ class ASRJsonVisitor : } dec_indent(); s.append("\n" + indtd); s.append("}"); - s.append(",\n" + indtd); append_location(s, x.base.base.loc.first, x.base.base.loc.last); dec_indent(); s.append("\n" + indtd); s.append("}"); @@ -224,15 +223,16 @@ class ASRJsonVisitor : } }; -std::string pickle_json(ASR::asr_t &asr, LocationManager &lm, bool show_intrinsic_modules) { +std::string pickle_json(ASR::asr_t &asr, LocationManager &lm, bool no_loc, bool show_intrinsic_modules) { ASRJsonVisitor v(lm); v.show_intrinsic_modules = show_intrinsic_modules; + v.no_loc = no_loc; v.visit_asr(asr); return v.get_str(); } -std::string pickle_json(ASR::TranslationUnit_t &asr, LocationManager &lm, bool show_intrinsic_modules) { - return pickle_json((ASR::asr_t &)asr, lm, show_intrinsic_modules); +std::string pickle_json(ASR::TranslationUnit_t &asr, LocationManager &lm, bool no_loc, bool show_intrinsic_modules) { + return pickle_json((ASR::asr_t &)asr, lm, no_loc, show_intrinsic_modules); } } // namespace LCompilers diff --git a/src/libasr/pickle.h b/src/libasr/pickle.h index b66b6774d51..748a59696e7 100644 --- a/src/libasr/pickle.h +++ b/src/libasr/pickle.h @@ -17,8 +17,8 @@ namespace LCompilers { std::string pickle_tree(ASR::TranslationUnit_t &asr, bool colors, bool show_intrinsic_modules=false); // Print Json structure - std::string pickle_json(ASR::asr_t &asr, LocationManager &lm, bool show_intrinsic_modules=false); - std::string pickle_json(ASR::TranslationUnit_t &asr, LocationManager &lm, bool show_intrinsic_modules=false); + std::string pickle_json(ASR::asr_t &asr, LocationManager &lm, bool no_loc, bool show_intrinsic_modules); + std::string pickle_json(ASR::TranslationUnit_t &asr, LocationManager &lm, bool no_loc, bool show_intrinsic_modules); } // namespace LCompilers diff --git a/src/libasr/runtime/lfortran_intrinsics.c b/src/libasr/runtime/lfortran_intrinsics.c index f01cb94550e..5495af7c501 100644 --- a/src/libasr/runtime/lfortran_intrinsics.c +++ b/src/libasr/runtime/lfortran_intrinsics.c @@ -12,6 +12,10 @@ #if defined(_MSC_VER) # include +# include +# define ftruncate _chsize_s +#else +# include #endif #include @@ -157,7 +161,7 @@ void handle_integer(char* format, int val, char** result) { dot_pos++; width = atoi(format + 1); min_width = atoi(dot_pos); - if (min_width > width) { + if (min_width > width && width != 0) { perror("Minimum number of digits cannot be more than the specified width for format.\n"); } } else { @@ -166,7 +170,7 @@ void handle_integer(char* format, int val, char** result) { width = len + sign_width; } } - if (width >= len + sign_width) { + if (width >= len + sign_width || width == 0) { if (min_width > len) { for (int i = 0; i < (width - min_width - sign_width); i++) { *result = append_to_string(*result, " "); @@ -177,7 +181,15 @@ void handle_integer(char* format, int val, char** result) { for (int i = 0; i < (min_width - len); i++) { *result = append_to_string(*result, "0"); } - } else { + } else if (width == 0) { + if (val < 0) { + *result = append_to_string(*result, "-"); + } + for (int i = 0; i < (min_width - len - sign_width); i++) { + *result = append_to_string(*result, "0"); + } + } + else { for (int i = 0; i < (width - len - sign_width); i++) { *result = append_to_string(*result, " "); } @@ -205,14 +217,15 @@ void handle_float(char* format, double val, char** result) { char int_str[64]; sprintf(int_str, "%ld", integer_part); char dec_str[64]; - sprintf(dec_str, "%f", decimal_part); + // TODO: This will work for up to `F65.60` but will fail for: + // print "(F67.62)", 1.23456789101112e-62_8 + sprintf(dec_str, "%.*lf", (60-integer_length), decimal_part); memmove(dec_str,dec_str+2,strlen(dec_str)); char* dot_pos = strchr(format, '.'); + decimal_digits = atoi(++dot_pos); width = atoi(format + 1); if (dot_pos != NULL) { - dot_pos++; - decimal_digits = atoi(dot_pos); if (width == 0) { if (decimal_digits == 0) { width = integer_length + sign_width + 1; @@ -237,12 +250,19 @@ void handle_float(char* format, double val, char** result) { for(int i=0;i 1){ + if (scale > 1) { decimal_digits -= scale - 1; } for (int i = 0; i < spaces; i++) { @@ -330,26 +357,44 @@ void handle_decimal(char* format, double val, int scale, char** result, char* c) for (int k = 0; k < abs(scale); k++) { strcat(formatted_value, "0"); } - if (decimal_digits + scale < strlen(val_str) && val != 0) { - long long t = (long long)round((double)atoll(val_str) / (long long)pow(10, (strlen(val_str) - decimal_digits - scale))); + int zeros = 0; + while(val_str[zeros] == '0') zeros++; + // TODO: figure out a way to round decimals with value < 1e-15 + if (decimal_digits + scale < strlen(val_str) && val != 0 && decimal_digits + scale - zeros<= 15) { + val_str[15] = '\0'; + long long t = (long long)round((long double)atoll(val_str) / (long long)pow(10, (strlen(val_str) - decimal_digits - scale))); sprintf(val_str, "%lld", t); + int index = zeros; + while(index--) strcat(formatted_value, "0"); } - strncat(formatted_value, val_str, decimal_digits + scale); + strncat(formatted_value, val_str, decimal_digits + scale - zeros); } else { - strcat(formatted_value, substring(val_str, 0, scale)); + char* temp = substring(val_str, 0, scale); + strcat(formatted_value, temp); strcat(formatted_value, "."); char* new_str = substring(val_str, scale, strlen(val_str)); - if (decimal_digits < strlen(new_str)) { - long long t = (long long)round((double)atoll(new_str) / (long long) pow(10, (strlen(new_str) - decimal_digits))); + int zeros = 0; + if (decimal_digits < strlen(new_str) && decimal_digits + scale <= 15) { + new_str[15] = '\0'; + zeros = strspn(new_str, "0"); + long long t = (long long)round((long double)atoll(new_str) / (long long) pow(10, (strlen(new_str) - decimal_digits))); sprintf(new_str, "%lld", t); + int index = zeros; + while(index--) { + memmove(new_str + 1, new_str, strlen(new_str)+1); + new_str[0] = '0'; + } } - strcat(formatted_value, substring(new_str, 0, decimal_digits)); + new_str[decimal_digits] = '\0'; + strcat(formatted_value, new_str); + free(new_str); + free(temp); } strcat(formatted_value, c); char exponent[12]; - if (atoi(format + 1) == 0){ + if (atoi(num_pos) == 0) { sprintf(exponent, "%+02d", (integer_length > 0 && integer_part != 0 ? integer_length - scale : decimal)); } else { sprintf(exponent, "%+03d", (integer_length > 0 && integer_part != 0 ? integer_length - scale : decimal)); @@ -373,17 +418,23 @@ void handle_decimal(char* format, double val, int scale, char** result, char* c) } } -char** parse_fortran_format(char* format, int *count) { - char** format_values_2 = NULL; +char** parse_fortran_format(char* format, int *count, int *item_start) { + char** format_values_2 = (char**)malloc((*count + 1) * sizeof(char*)); int format_values_count = *count; int index = 0 , start = 0; while (format[index] != '\0') { - format_values_2 = (char**)realloc(format_values_2, (format_values_count + 1) * sizeof(char*)); + char** ptr = (char**)realloc(format_values_2, (format_values_count + 1) * sizeof(char*)); + if (ptr == NULL) { + perror("Memory allocation failed.\n"); + free(format_values_2); + } else { + format_values_2 = ptr; + } switch (tolower(format[index])) { case ',' : break; case '/' : - format_values_2[format_values_count++] = "/"; + format_values_2[format_values_count++] = substring(format, index, index+1); break; case '"' : start = index++; @@ -413,33 +464,41 @@ char** parse_fortran_format(char* format, int *count) { case 'e' : case 'f' : start = index++; + if(tolower(format[index]) == 's') index++; while (isdigit(format[index])) index++; if (format[index] == '.') index++; while (isdigit(format[index])) index++; format_values_2[format_values_count++] = substring(format, start, index); index--; break; + case '(' : + start = index++; + while (format[index] != ')') index++; + format_values_2[format_values_count++] = substring(format, start, index+1); + *item_start = format_values_count; + break; default : if (isdigit(format[index]) && tolower(format[index+1]) == 'p') { start = index; - if (format[index-1] == '-') { + if (index > 0 && format[index-1] == '-') { start = index - 1; } - index = index + 3; - while (isdigit(format[index])) index++; - if (format[index] == '.') index++; - while (isdigit(format[index])) index++; - format_values_2[format_values_count++] = substring(format, start, index); - index--; + index = index + 1; + format_values_2[format_values_count++] = substring(format, start, index + 1); } else if (isdigit(format[index])) { - char* fmt; start = index; while (isdigit(format[index])) index++; - int repeat = atoi(substring(format, start, index)); + char* repeat_str = substring(format, start, index); + int repeat = atoi(repeat_str); + free(repeat_str); + format_values_2 = (char**)realloc(format_values_2, (format_values_count + repeat + 1) * sizeof(char*)); if (format[index] == '(') { start = index++; while (format[index] != ')') index++; - fmt = substring(format, start, index+1); + *item_start = format_values_count+1; + for (int i = 0; i < repeat; i++) { + format_values_2[format_values_count++] = substring(format, start, index+1); + } } else { start = index++; if (isdigit(format[index])) { @@ -447,11 +506,10 @@ char** parse_fortran_format(char* format, int *count) { if (format[index] == '.') index++; while (isdigit(format[index])) index++; } - fmt = substring(format, start, index); - } - for (int i = 0; i < repeat; i++) { - format_values_2[format_values_count++] = fmt; - format_values_2 = (char**)realloc(format_values_2, (format_values_count + 1) * sizeof(char*)); + for (int i = 0; i < repeat; i++) { + format_values_2[format_values_count++] = substring(format, start, index); + } + index--; } } } @@ -466,59 +524,36 @@ LFORTRAN_API char* _lcompilers_string_format_fortran(int count, const char* form va_list args; va_start(args, format); int len = strlen(format); - char* modified_input_string = (char*)malloc(len * sizeof(char)); - strcpy(modified_input_string,format); + char* modified_input_string = (char*)malloc((len+1) * sizeof(char)); + strncpy(modified_input_string, format, len); + modified_input_string[len] = '\0'; if (format[0] == '(' && format[len-1] == ')') { - modified_input_string = substring(format, 1, len - 1); + memmove(modified_input_string, modified_input_string + 1, strlen(modified_input_string)); + modified_input_string[len-1] = '\0'; } - char** format_values = (char**)malloc(sizeof(char*)); - int format_values_count = 0; - format_values = parse_fortran_format(modified_input_string,&format_values_count); + int format_values_count = 0,item_start_idx=0; + char** format_values = parse_fortran_format(modified_input_string,&format_values_count,&item_start_idx); char* result = (char*)malloc(sizeof(char)); result[0] = '\0'; + int item_start = 0; while (1) { - for (int i = 0; i < format_values_count; i++) { + int scale = 0; + for (int i = item_start; i < format_values_count; i++) { + if(format_values[i] == NULL) continue; char* value = format_values[i]; - if (value[0] == '/') { - // Slash Editing (newlines) - int j = 0; - while (value[j] == '/') { - result = append_to_string(result, "\n"); - j++; - } - value = substring(value, j, strlen(value)); - } - - int newline = 0; - if (value[strlen(value) - 1] == '/') { - // Newlines at the end of the argument - int j = strlen(value) - 1; - while (value[j] == '/') { - newline++; - j--; - } - value = substring(value, 0, strlen(value) - newline); - } - - int scale = 0; - if (isdigit(value[0]) && tolower(value[1]) == 'p') { - // Scale Factor (nP) - scale = atoi(&value[0]); - value = substring(value, 2, strlen(value)); - } else if (value[0] == '-' && isdigit(value[1]) && tolower(value[2]) == 'p') { - scale = atoi(substring(value, 0, 2)); - value = substring(value, 3, strlen(value)); - } - if (value[0] == '(' && value[strlen(value)-1] == ')') { - value = substring(value, 1, strlen(value)-1); - char** new_fmt_val = (char**)malloc(sizeof(char*)); + value[strlen(value)-1] = '\0'; int new_fmt_val_count = 0; - new_fmt_val = parse_fortran_format(value,&new_fmt_val_count); - - format_values = (char**)realloc(format_values, (format_values_count + new_fmt_val_count + 1) * sizeof(char*)); - int totalSize = format_values_count + new_fmt_val_count; + char** new_fmt_val = parse_fortran_format(++value,&new_fmt_val_count,&item_start_idx); + + char** ptr = (char**)realloc(format_values, (format_values_count + new_fmt_val_count + 1) * sizeof(char*)); + if (ptr == NULL) { + perror("Memory allocation failed.\n"); + free(format_values); + } else { + format_values = ptr; + } for (int k = format_values_count - 1; k >= i+1; k--) { format_values[k + new_fmt_val_count] = format_values[k]; } @@ -526,32 +561,47 @@ LFORTRAN_API char* _lcompilers_string_format_fortran(int count, const char* form format_values[i + 1 + k] = new_fmt_val[k]; } format_values_count = format_values_count + new_fmt_val_count; - format_values[i] = ""; + free(format_values[i]); + format_values[i] = NULL; + free(new_fmt_val); continue; } - if ((value[0] == '\"' && value[strlen(value) - 1] == '\"') || + if (value[0] == '/') { + result = append_to_string(result, "\n"); + } else if (isdigit(value[0]) && tolower(value[1]) == 'p') { + // Scale Factor nP + scale = atoi(&value[0]); + } else if (value[0] == '-' && isdigit(value[1]) && tolower(value[2]) == 'p') { + char temp[3] = {value[0],value[1],'\0'}; + scale = atoi(temp); + } else if ((value[0] == '\"' && value[strlen(value) - 1] == '\"') || (value[0] == '\'' && value[strlen(value) - 1] == '\'')) { // String value = substring(value, 1, strlen(value) - 1); result = append_to_string(result, value); + free(value); } else if (tolower(value[0]) == 'a') { // Character Editing (A[n]) - char* str = substring(value, 1, strlen(value)); if ( count == 0 ) break; count--; char* arg = va_arg(args, char*); if (arg == NULL) continue; - if (strlen(str) == 0) { - sprintf(str, "%lu", strlen(arg)); + if (strlen(value) == 1) { + result = append_to_string(result, arg); + } else { + char* str = (char*)malloc((strlen(value)) * sizeof(char)); + memmove(str, value+1, strlen(value)); + int buffer_size = 20; + char* s = (char*)malloc(buffer_size * sizeof(char)); + snprintf(s, buffer_size, "%%%s.%ss", str, str); + char* string = (char*)malloc((atoi(str) + 1) * sizeof(char)); + sprintf(string,s, arg); + result = append_to_string(result, string); + free(str); + free(s); + free(string); } - char* s = (char*)malloc((strlen(str) + 4) * sizeof(char)); - sprintf(s, "%%%s.%ss", str, str); - char* string = (char*)malloc((strlen(arg) + 4) * sizeof(char)); - sprintf(string, s, arg); - result = append_to_string(result, string); - free(s); - free(string); } else if (tolower(value[strlen(value) - 1]) == 'x') { result = append_to_string(result, " "); } else if (tolower(value[0]) == 'i') { @@ -582,19 +632,19 @@ LFORTRAN_API char* _lcompilers_string_format_fortran(int count, const char* form printf("Printing support is not available for %s format.\n",value); } - while (newline != 0) { - result = append_to_string(result, "\n"); - newline--; - } } if ( count > 0 ) { result = append_to_string(result, "\n"); + item_start = item_start_idx; } else { break; } } free(modified_input_string); + for (int i = 0;i= 0 && str[end] == ' ') end--; + return end + 1; +} + int str_compare(char **s1, char **s2) { - int s1_len = strlen(*s1); - int s2_len = strlen(*s2); + int s1_len = strlen_without_trailing_space(*s1); + int s2_len = strlen_without_trailing_space(*s2); int lim = MIN(s1_len, s2_len); int res = 0; int i ; @@ -1862,6 +1929,7 @@ void store_unit_file(int32_t unit_num, FILE* filep, bool unit_file_bin) { } FILE* get_file_pointer_from_unit(int32_t unit_num, bool *unit_file_bin) { + *unit_file_bin = false; for( int i = 0; i <= last_index_used; i++ ) { if( unit_to_file[i].unit == unit_num ) { *unit_file_bin = unit_to_file[i].unit_file_bin; @@ -1903,44 +1971,68 @@ LFORTRAN_API int64_t _lfortran_open(int32_t unit_num, char *f_name, char *status if (form == NULL) { form = "formatted"; } - - if (streql(status, "old") || - streql(status, "new") || - streql(status, "replace") || - streql(status, "scratch") || - streql(status, "unknown")) { - // TODO: status can be one of the above. We need to support it - /* - "old" (file must already exist), If it does not exist, the open operation will fail - "new" (file does not exist and will be created) - "replace" (file will be created, replacing any existing file) - "scratch" (temporary file will be deleted when closed) - "unknown" (it is not known whether the file exists) - */ + bool file_exists[1] = {false}; + _lfortran_inquire(f_name, file_exists, -1, NULL); + char *access_mode = NULL; + /* + STATUS=`specifier` in the OPEN statement + The following are the available specifiers: + * "old" (file must already exist) + * "new" (file does not exist and will be created) + * "scratch" (temporary file will be deleted when closed) + * "replace" (file will be created, replacing any existing file) + * "unknown" (it is not known whether the file exists) + */ + if (streql(status, "old")) { + if (!*file_exists) { + printf("Runtime error: File `%s` does not exists!\nCannot open a " + "file with the `status=old`\n", f_name); + exit(1); + } + access_mode = "r+"; + } else if (streql(status, "new")) { + if (*file_exists) { + printf("Runtime error: File `%s` exists!\nCannot open a file with " + "the `status=new`\n", f_name); + exit(1); + } + access_mode = "w+"; + } else if (streql(status, "replace")) { + access_mode = "w+"; + } else if (streql(status, "unknown")) { + if (!*file_exists) { + FILE *fd = fopen(f_name, "w"); + if (fd) { + fclose(fd); + } + } + access_mode = "r+"; + } else if (streql(status, "scratch")) { + printf("Runtime error: Unhandled type status=`scratch`\n"); + exit(1); } else { - printf("Error: STATUS specifier in OPEN statement has invalid value '%s'\n", status); + printf("Runtime error: STATUS specifier in OPEN statement has " + "invalid value '%s'\n", status); exit(1); } - char *access_mode = NULL; bool unit_file_bin; - if (streql(form, "formatted")) { - access_mode = "r"; unit_file_bin = false; } else if (streql(form, "unformatted")) { + // TODO: Handle unformatted write to a file access_mode = "rb"; unit_file_bin = true; } else { - printf("Error: FORM specifier in OPEN statement has invalid value '%s'\n", status); + printf("Runtime error: FORM specifier in OPEN statement has " + "invalid value '%s'\n", form); exit(1); } - FILE *fd; - fd = fopen(f_name, access_mode); + FILE *fd = fopen(f_name, access_mode); if (!fd) { - printf("Error in opening the file!\n"); + printf("Runtime error: Error in opening the file!\n"); perror(f_name); exit(1); } @@ -1994,6 +2086,29 @@ LFORTRAN_API void _lfortran_rewind(int32_t unit_num) rewind(filep); } +LFORTRAN_API void _lfortran_backspace(int32_t unit_num) +{ + bool unit_file_bin; + FILE* fd = get_file_pointer_from_unit(unit_num, &unit_file_bin); + if( fd == NULL ) { + printf("Specified UNIT %d in BACKSPACE is not created or connected.\n", + unit_num); + exit(1); + } + int n = ftell(fd); + for(int i = n; i >= 0; i --) { + char c = fgetc(fd); + if (i == n) { + // Skip previous record newline + fseek(fd, -3, SEEK_CUR); + continue; + } else if (c == '\n') { + break; + } else { + fseek(fd, -2, SEEK_CUR); + } + } +} LFORTRAN_API void _lfortran_read_int32(int32_t *p, int32_t unit_num) { @@ -2126,6 +2241,10 @@ LFORTRAN_API void _lfortran_read_char(char **p, int32_t unit_num) } else { fscanf(filep, "%s", *p); } + if (streql(*p, "")) { + printf("Runtime error: End of file!\n"); + exit(1); + } } LFORTRAN_API void _lfortran_read_float(float *p, int32_t unit_num) @@ -2286,11 +2405,36 @@ LFORTRAN_API void _lfortran_formatted_read(int32_t unit_num, int32_t* iostat, ch exit(1); } - *iostat = !(fgets(*arg, n, filep) == *arg); + *iostat = !(fgets(*arg, n+1, filep) == *arg); (*arg)[strcspn(*arg, "\n")] = 0; va_end(args); } +LFORTRAN_API void _lfortran_empty_read(int32_t unit_num, int32_t* iostat) { + bool unit_file_bin; + FILE* fp = get_file_pointer_from_unit(unit_num, &unit_file_bin); + if (!fp) { + printf("No file found with given unit\n"); + exit(1); + } + + if (!unit_file_bin) { + // The contents of `c` are ignored + char c = fgetc(fp); + while (c != '\n' && c != EOF) { + c = fgetc(fp); + } + + if (feof(fp)) { + *iostat = -1; + } else if (ferror(fp)) { + *iostat = 1; + } else { + *iostat = 0; + } + } +} + LFORTRAN_API char* _lpython_read(int64_t fd, int64_t n) { char *c = (char *) calloc(n, sizeof(char)); @@ -2304,6 +2448,35 @@ LFORTRAN_API char* _lpython_read(int64_t fd, int64_t n) return c; } +LFORTRAN_API void _lfortran_file_write(int32_t unit_num, const char *format, ...) +{ + bool unit_file_bin; + FILE* filep = get_file_pointer_from_unit(unit_num, &unit_file_bin); + if (!filep) { + filep = stdout; + } + if (unit_file_bin) { + printf("Binary content is not handled by write(..)\n"); + exit(1); + } + va_list args; + va_start(args, format); + vfprintf(filep, format, args); + va_end(args); + + ftruncate(fileno(filep), ftell(filep)); +} + +LFORTRAN_API void _lfortran_string_write(char **str, const char *format, ...) { + va_list args; + va_start(args, format); + char *s = (char *) malloc(strlen(*str)*sizeof(char)); + vsprintf(s, format, args); + _lfortran_strcpy(str, s, 0); + free(s); + va_end(args); +} + LFORTRAN_API void _lpython_close(int64_t fd) { if (fclose((FILE*)fd) != 0) @@ -2398,8 +2571,14 @@ struct Stacktrace get_stacktrace_addresses() { } char *get_base_name(char *filename) { - size_t start = strrchr(filename, '/')-filename+1; + // Assuming filename always has an extensions size_t end = strrchr(filename, '.')-filename-1; + // Check for directories else start at 0th index + char *slash_idx_ptr = strrchr(filename, '/'); + size_t start = 0; + if (slash_idx_ptr) { + start = slash_idx_ptr - filename+1; + } int nos_of_chars = end - start + 1; char *base_name; if (nos_of_chars < 0) { diff --git a/src/libasr/runtime/lfortran_intrinsics.h b/src/libasr/runtime/lfortran_intrinsics.h index 3ea56e49915..7efaa10b85a 100644 --- a/src/libasr/runtime/lfortran_intrinsics.h +++ b/src/libasr/runtime/lfortran_intrinsics.h @@ -266,6 +266,9 @@ LFORTRAN_API void _lfortran_read_float(float *p, int32_t unit_num); LFORTRAN_API void _lfortran_read_array_float(float *p, int array_size, int32_t unit_num); LFORTRAN_API void _lfortran_read_array_double(double *p, int array_size, int32_t unit_num); LFORTRAN_API void _lfortran_read_char(char **p, int32_t unit_num); +LFORTRAN_API void _lfortran_string_write(char **str, const char *format, ...); +LFORTRAN_API void _lfortran_file_write(int32_t unit_num, const char *format, ...); +LFORTRAN_API void _lfortran_empty_read(int32_t unit_num, int32_t* iostat); LFORTRAN_API void _lpython_close(int64_t fd); LFORTRAN_API void _lfortran_close(int32_t unit_num); LFORTRAN_API int32_t _lfortran_ichar(char *c); diff --git a/src/libasr/string_utils.cpp b/src/libasr/string_utils.cpp index 55e5a6b2a35..b1b1e92aef9 100644 --- a/src/libasr/string_utils.cpp +++ b/src/libasr/string_utils.cpp @@ -37,6 +37,22 @@ char *s2c(Allocator &al, const std::string &s) { return x.c_str(al); } +// Splits the string `s` using the separator `split_string` +std::vector string_split(const std::string &s, const std::string &split_string) +{ + std::vector result; + size_t old_pos = 0; + size_t new_pos; + while ((new_pos = s.find(split_string, old_pos)) != std::string::npos) { + std::string substr = s.substr(old_pos, new_pos-old_pos); + if (substr.size() > 0) result.push_back(substr); + old_pos = new_pos+split_string.size(); + } + result.push_back(s.substr(old_pos)); + return result; +} + +// Splits the string `s` using any space or newline std::vector split(const std::string &s) { std::vector result; diff --git a/src/libasr/string_utils.h b/src/libasr/string_utils.h index 505ce438351..d41e3eb82be 100644 --- a/src/libasr/string_utils.h +++ b/src/libasr/string_utils.h @@ -14,6 +14,7 @@ namespace LCompilers { bool startswith(const std::string &s, const std::string &e); bool endswith(const std::string &s, const std::string &e); std::string to_lower(const std::string &s); +std::vector string_split(const std::string &s, const std::string &split_string); std::vector split(const std::string &s); std::string join(const std::string j, const std::vector &v); std::vector slice(const std::vector &v, diff --git a/src/libasr/utils.h b/src/libasr/utils.h index 8f46ada7a42..5780008bf93 100644 --- a/src/libasr/utils.h +++ b/src/libasr/utils.h @@ -21,14 +21,46 @@ std::string pf2s(Platform); Platform get_platform(); std::string get_unique_ID(); +int visualize_json(std::string &astr_data_json, LCompilers::Platform os); +std::string generate_visualize_html(std::string &astr_data_json); -struct CompilerOptions { +struct PassOptions { std::filesystem::path mod_files_dir; std::vector include_dirs; - std::vector runtime_linker_paths; + std::string run_fun; // for global_stmts pass // TODO: Convert to std::filesystem::path (also change find_and_load_module()) std::string runtime_library_dir; + bool always_run = false; // for unused_functions pass + bool inline_external_symbol_calls = true; // for inline_function_calls pass + int64_t unroll_factor = 32; // for loop_unroll pass + bool fast = false; // is fast flag enabled. + bool verbose = false; // For developer debugging + bool dump_all_passes = false; // For developer debugging + bool dump_fortran = false; // For developer debugging + bool pass_cumulative = false; // Apply passes cumulatively + bool disable_main = false; + bool use_loop_variable_after_loop = false; + bool realloc_lhs = false; + std::vector skip_optimization_func_instantiation; + bool module_name_mangling = false; + bool global_symbols_mangling = false; + bool intrinsic_symbols_mangling = false; + bool all_symbols_mangling = false; + bool bindc_mangling = false; + bool mangle_underscore = false; + bool json = false; + bool no_loc = false; + bool visualize = false; + bool tree = false; + bool with_intrinsic_mods = false; +}; + +struct CompilerOptions { + std::vector runtime_linker_paths; + + // TODO: Convert to std::filesystem::path (also change find_and_load_module()) + PassOptions po; bool fixed_form = false; bool interactive = false; @@ -61,20 +93,13 @@ struct CompilerOptions { std::string arg_o = ""; bool emit_debug_info = false; bool emit_debug_line_column = false; - bool verbose = false; - bool dump_all_passes = false; - bool pass_cumulative = false; bool enable_cpython = false; bool enable_symengine = false; bool link_numpy = false; - bool realloc_lhs = false; - bool module_name_mangling = false; - bool global_symbols_mangling = false; - bool intrinsic_symbols_mangling = false; - bool all_symbols_mangling = false; - bool bindc_mangling = false; - bool mangle_underscore = false; bool run = false; + bool legacy_array_sections = false; + bool ignore_pragma = false; + bool stack_arrays = false; std::vector import_paths; Platform platform; @@ -88,34 +113,4 @@ int initialize(); } // namespace LCompilers -namespace LCompilers { - - struct PassOptions { - std::filesystem::path mod_files_dir; - std::vector include_dirs; - - std::string run_fun; // for global_stmts pass - // TODO: Convert to std::filesystem::path (also change find_and_load_module()) - std::string runtime_library_dir; - bool always_run = false; // for unused_functions pass - bool inline_external_symbol_calls = true; // for inline_function_calls pass - int64_t unroll_factor = 32; // for loop_unroll pass - bool fast = false; // is fast flag enabled. - bool verbose = false; // For developer debugging - bool dump_all_passes = false; // For developer debugging - bool pass_cumulative = false; // Apply passes cumulatively - bool disable_main = false; - bool use_loop_variable_after_loop = false; - bool realloc_lhs = false; - std::vector skip_optimization_func_instantiation; - bool module_name_mangling = false; - bool global_symbols_mangling = false; - bool intrinsic_symbols_mangling = false; - bool all_symbols_mangling = false; - bool bindc_mangling = false; - bool mangle_underscore = false; - }; - -} - #endif // LIBASR_UTILS_H diff --git a/src/libasr/utils2.cpp b/src/libasr/utils2.cpp index 51a2a3c5f4d..075bdb9923e 100644 --- a/src/libasr/utils2.cpp +++ b/src/libasr/utils2.cpp @@ -3,9 +3,11 @@ #include #endif +#include #include #include #include +#include #include #include @@ -61,6 +63,184 @@ bool present(char** const v, size_t n, const std::string name) { return false; } +int visualize_json(std::string &astr_data_json, LCompilers::Platform os) { + using namespace LCompilers; + std::hash hasher; + std::string file_name = "visualize" + std::to_string(hasher(astr_data_json)) + ".html"; + std::ofstream out; + out.open(file_name); + out << LCompilers::generate_visualize_html(astr_data_json); + out.close(); + std::string open_cmd = ""; + switch (os) { + case Linux: open_cmd = "xdg-open"; break; + case Windows: open_cmd = "start"; break; + case macOS_Intel: + case macOS_ARM: open_cmd = "open"; break; + default: + std::cerr << "Unsupported Platform " << pf2s(os) < + + + LCompilers AST/R Visualization + + + + + + + \n"; + out << R"( + + + + + +)"; + return out.str(); +} + std::string pf2s(Platform p) { switch (p) { case (Platform::Linux) : return "Linux"; diff --git a/src/lpython/python_evaluator.cpp b/src/lpython/python_evaluator.cpp index c38fb92fea6..44075e0a847 100644 --- a/src/lpython/python_evaluator.cpp +++ b/src/lpython/python_evaluator.cpp @@ -76,7 +76,7 @@ Result> PythonCompiler::get_llvm3( return res.error; } - if (compiler_options.fast) { + if (compiler_options.po.fast) { e->opt(*m->m_m); } diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index ea250d06a87..f602110c2db 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -1200,7 +1200,7 @@ class CommonVisitor : public AST::BaseVisitor { Vec args_new; args_new.reserve(al, func->n_args); visit_expr_list_with_cast(func->m_args, func->n_args, args_new, args); - + if (ASRUtils::symbol_parent_symtab(stemp)->get_counter() != current_scope->get_counter()) { ADD_ASR_DEPENDENCIES(current_scope, stemp, dependencies); } @@ -1457,12 +1457,12 @@ class CommonVisitor : public AST::BaseVisitor { + std::to_string(new_function_num); generic_func_subs[new_func_name] = subs; SymbolTable *target_scope = ASRUtils::symbol_parent_symtab(sym); - t = pass_instantiate_symbol(al, context_map, subs, rt_subs, + t = instantiate_symbol(al, context_map, subs, rt_subs, target_scope, target_scope, new_func_name, sym); if (ASR::is_a(*sym)) { ASR::Function_t *f = ASR::down_cast(sym); ASR::Function_t *new_f = ASR::down_cast(t); - t = pass_instantiate_function_body(al, context_map, subs, rt_subs, + t = instantiate_function_body(al, context_map, subs, rt_subs, target_scope, target_scope, new_f, f); } dependencies.erase(s2c(al, func_name)); @@ -3076,7 +3076,7 @@ class CommonVisitor : public AST::BaseVisitor { bool is_packed = false; if( !is_dataclass(x.m_decorator_list, x.n_decorator_list, algined_expr, is_packed) ) { - throw SemanticError("Only dataclass decorated classes and Enum subclasses are supported.", + throw SemanticError("Only dataclass-decorated classes and Enum subclasses are supported.", x.base.base.loc); }