From 5991115630d44ef1d8917038b25307e09dc9dafc Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Wed, 1 Nov 2023 14:44:40 +0530 Subject: [PATCH] Added support for visit_Assert through basic_eq --- src/libasr/pass/replace_symbolic.cpp | 34 ++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index f84136ac97..44ed0ecfa8 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -1706,16 +1706,32 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*x.m_test)) { - ASR::SymbolicCompare_t *s = ASR::down_cast(x.m_test); - - ASR::symbol_t* basic_str_sym = declare_basic_str_function(al, x.base.base.loc, module_scope); - left_tmp = process_with_basic_str(al, x.base.base.loc, s->m_left, basic_str_sym); - right_tmp = process_with_basic_str(al, x.base.base.loc, s->m_right, basic_str_sym); - ASR::expr_t* test = ASRUtils::EXPR(ASR::make_StringCompare_t(al, x.base.base.loc, left_tmp, - s->m_op, right_tmp, s->m_type, s->m_value)); + ASR::SymbolicCompare_t* s = ASR::down_cast(x.m_test); + if (s->m_op == ASR::cmpopType::Eq || s->m_op == ASR::cmpopType::NotEq) { + ASR::symbol_t* sym = nullptr; + if (s->m_op == ASR::cmpopType::Eq) { + sym = declare_basic_eq_function(al, x.base.base.loc, module_scope); + } else { + sym = declare_basic_neq_function(al, x.base.base.loc, module_scope); + } + ASR::expr_t* value1 = handle_argument(al, x.base.base.loc, s->m_left); + ASR::expr_t* value2 = handle_argument(al, x.base.base.loc, s->m_right); + Vec call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg1, call_arg2; + call_arg1.loc = x.base.base.loc; + call_arg1.m_value = value1; + call_arg2.loc = x.base.base.loc; + call_arg2.m_value = value2; + call_args.push_back(al, call_arg1); + call_args.push_back(al, call_arg2); + ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc, + sym, sym, call_args.p, call_args.n, + ASRUtils::TYPE(ASR::make_Logical_t(al, x.base.base.loc, 4)), nullptr, nullptr)); - ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, test, x.m_msg)); - pass_result.push_back(al, assert_stmt); + ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, function_call, x.m_msg)); + pass_result.push_back(al, assert_stmt); + } } else if (ASR::is_a(*x.m_test)) { ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast(x.m_test); if (intrinsic_func->m_type->type == ASR::ttypeType::Logical) {