From 6e2de36a3637b673af70fb61651b2bc964ea4d06 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Mon, 9 Oct 2023 08:25:22 +0530 Subject: [PATCH 1/2] Supporting assignment through basic_assign --- src/libasr/pass/replace_symbolic.cpp | 60 +++++++++++++++++++++++++++- 1 file changed, 59 insertions(+), 1 deletion(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 59bdf45f81..9bb9ef414a 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -548,6 +548,45 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_symbol(name)) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(module_scope); + + Vec args; + args.reserve(al, 2); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg1); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1))); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "y"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); + + Vec body; + body.reserve(al, 1); + + Vec dep; + dep.reserve(al, 1); + + ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc, + fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n, + nullptr, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header)); + ASR::symbol_t* symbol = ASR::down_cast(subrout); + module_scope->add_symbol(s2c(al, name), symbol); + } + return module_scope->get_symbol(name); + } + ASR::symbol_t* declare_basic_str_function(Allocator& al, const Location& loc, SymbolTable* module_scope) { std::string name = "basic_str"; symbolic_dependencies.push_back(name); @@ -794,7 +833,26 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorparent; - if (ASR::is_a(*x.m_value)) { + if (ASR::is_a(*x.m_value) && ASR::is_a(*ASRUtils::expr_type(x.m_value))) { + ASR::symbol_t *v = ASR::down_cast(x.m_value)->m_v; + if (symbolic_vars_to_free.find(v) == symbolic_vars_to_free.end()) return; + ASR::symbol_t* basic_assign_sym = declare_basic_assign_function(al, x.base.base.loc, module_scope); + ASR::symbol_t* var_sym = ASR::down_cast(x.m_value)->m_v; + ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + + Vec call_args; + call_args.reserve(al, 2); + ASR::call_arg_t call_arg1, call_arg2; + call_arg1.loc = x.base.base.loc; + call_arg1.m_value = x.m_target; + call_arg2.loc = x.base.base.loc; + call_arg2.m_value = target; + call_args.push_back(al, call_arg1); + call_args.push_back(al, call_arg2); + ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, basic_assign_sym, + basic_assign_sym, call_args.p, call_args.n, nullptr)); + pass_result.push_back(al, stmt); + } else if (ASR::is_a(*x.m_value)) { ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast(x.m_value); if (intrinsic_func->m_type->type == ASR::ttypeType::SymbolicExpression) { process_intrinsic_function(al, x.base.base.loc, intrinsic_func, module_scope, x.m_target); From 1830fc916b2a70fe64caf734aca1a4a6c952fed3 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Mon, 9 Oct 2023 08:36:10 +0530 Subject: [PATCH 2/2] Added tests --- integration_tests/symbolics_01.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/integration_tests/symbolics_01.py b/integration_tests/symbolics_01.py index 525cf6ab4c..3187a9005f 100644 --- a/integration_tests/symbolics_01.py +++ b/integration_tests/symbolics_01.py @@ -6,7 +6,10 @@ def main0(): y: S = Symbol('y') x = pi z: S = x + y + x = z + print(x) print(z) + assert(x == z) assert(z == pi + y) assert(z != S(2)*pi + y)