diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 44ed0ecfa8..4e29c0b082 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -89,27 +89,29 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi this->current_scope = current_scope_copy; // freeing out variables - std::string new_name = "basic_free_stack"; - ASR::symbol_t* basic_free_stack_sym = module_scope->get_symbol(new_name); - Vec<ASR::stmt_t*> func_body; - func_body.from_pointer_n_copy(al, xx.m_body, xx.n_body); + if (!symbolic_vars_to_free.empty()) { + std::string new_name = "basic_free_stack"; + ASR::symbol_t* basic_free_stack_sym = module_scope->get_symbol(new_name); + Vec<ASR::stmt_t*> func_body; + func_body.from_pointer_n_copy(al, xx.m_body, xx.n_body); + + for (ASR::symbol_t* symbol : symbolic_vars_to_free) { + if (symbolic_vars_to_omit.find(symbol) != symbolic_vars_to_omit.end()) continue; + Vec<ASR::call_arg_t> call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg; + call_arg.loc = xx.base.base.loc; + call_arg.m_value = ASRUtils::EXPR(ASR::make_Var_t(al, xx.base.base.loc, symbol)); + call_args.push_back(al, call_arg); + ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, xx.base.base.loc, basic_free_stack_sym, + basic_free_stack_sym, call_args.p, call_args.n, nullptr)); + func_body.push_back(al, stmt); + } - for (ASR::symbol_t* symbol : symbolic_vars_to_free) { - if (symbolic_vars_to_omit.find(symbol) != symbolic_vars_to_omit.end()) continue; - Vec<ASR::call_arg_t> call_args; - call_args.reserve(al, 1); - ASR::call_arg_t call_arg; - call_arg.loc = xx.base.base.loc; - call_arg.m_value = ASRUtils::EXPR(ASR::make_Var_t(al, xx.base.base.loc, symbol)); - call_args.push_back(al, call_arg); - ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, xx.base.base.loc, basic_free_stack_sym, - basic_free_stack_sym, call_args.p, call_args.n, nullptr)); - func_body.push_back(al, stmt); + xx.n_body = func_body.size(); + xx.m_body = func_body.p; + symbolic_vars_to_free.clear(); } - - xx.n_body = func_body.size(); - xx.m_body = func_body.p; - symbolic_vars_to_free.clear(); } void visit_Variable(const ASR::Variable_t& x) { @@ -121,7 +123,9 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi ASR::ttype_t *type1 = ASRUtils::TYPE(ASR::make_CPtr_t(al, xx.base.base.loc)); xx.m_type = type1; - symbolic_vars_to_free.insert(ASR::down_cast<ASR::symbol_t>((ASR::asr_t*)&xx)); + if (var_name != "_lpython_return_variable") { + symbolic_vars_to_free.insert(ASR::down_cast<ASR::symbol_t>((ASR::asr_t*)&xx)); + } if(xx.m_intent == ASR::intentType::In){ symbolic_vars_to_omit.insert(ASR::down_cast<ASR::symbol_t>((ASR::asr_t*)&xx)); } @@ -1762,6 +1766,30 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi } } } + + void visit_Return(const ASR::Return_t &x) { + if (!symbolic_vars_to_free.empty()){ + SymbolTable* module_scope = current_scope->parent; + // freeing out variables + std::string new_name = "basic_free_stack"; + ASR::symbol_t* basic_free_stack_sym = module_scope->get_symbol(new_name); + + for (ASR::symbol_t* symbol : symbolic_vars_to_free) { + if (symbolic_vars_to_omit.find(symbol) != symbolic_vars_to_omit.end()) continue; + Vec<ASR::call_arg_t> call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg; + call_arg.loc = x.base.base.loc; + call_arg.m_value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, symbol)); + call_args.push_back(al, call_arg); + ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, basic_free_stack_sym, + basic_free_stack_sym, call_args.p, call_args.n, nullptr)); + pass_result.push_back(al, stmt); + } + symbolic_vars_to_free.clear(); + pass_result.push_back(al, ASRUtils::STMT(ASR::make_Return_t(al, x.base.base.loc))); + } + } }; void pass_replace_symbolic(Allocator &al, ASR::TranslationUnit_t &unit,