Skip to content

Commit 36fe6cf

Browse files
authored
Merge pull request #2374 from anutosh491/implementing_symbolic_comparison
Added support for comparing symbolic expressions
2 parents 670b12f + 4ec4a76 commit 36fe6cf

File tree

2 files changed

+161
-1
lines changed

2 files changed

+161
-1
lines changed

integration_tests/symbolics_02.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
from sympy import Symbol
1+
from sympy import Symbol, pi
22
from lpython import S
33

44
def test_symbolic_operations():
55
x: S = Symbol('x')
66
y: S = Symbol('y')
7+
p1: S = pi
8+
p2: S = pi
79

810
# Addition
911
z: S = x + y
@@ -37,4 +39,19 @@ def test_symbolic_operations():
3739
assert(c == S(0))
3840
print(c)
3941

42+
# Comparison
43+
b1: bool = p1 == p2
44+
print(b1)
45+
assert(b1 == True)
46+
b2: bool = p1 != pi
47+
print(b2)
48+
assert(b2 == False)
49+
b3: bool = p1 != x
50+
print(b3)
51+
assert(b3 == True)
52+
b4: bool = pi == Symbol("x")
53+
print(b4)
54+
assert(b4 == False)
55+
56+
4057
test_symbolic_operations()

src/libasr/pass/replace_symbolic.cpp

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,96 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
626626
return module_scope->get_symbol(name);
627627
}
628628

629+
ASR::symbol_t* declare_basic_eq_function(Allocator& al, const Location& loc, SymbolTable* module_scope) {
630+
std::string name = "basic_eq";
631+
symbolic_dependencies.push_back(name);
632+
if (!module_scope->get_symbol(name)) {
633+
std::string header = "symengine/cwrapper.h";
634+
SymbolTable* fn_symtab = al.make_new<SymbolTable>(module_scope);
635+
636+
Vec<ASR::expr_t*> args;
637+
args.reserve(al, 1);
638+
ASR::symbol_t* arg1 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
639+
al, loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar,
640+
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)),
641+
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false));
642+
fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1);
643+
ASR::symbol_t* arg2 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
644+
al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
645+
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
646+
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
647+
fn_symtab->add_symbol(s2c(al, "x"), arg2);
648+
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2)));
649+
ASR::symbol_t* arg3 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
650+
al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In,
651+
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
652+
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
653+
fn_symtab->add_symbol(s2c(al, "y"), arg3);
654+
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg3)));
655+
656+
Vec<ASR::stmt_t*> body;
657+
body.reserve(al, 1);
658+
659+
Vec<char*> dep;
660+
dep.reserve(al, 1);
661+
662+
ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, fn_symtab->get_symbol("_lpython_return_variable")));
663+
ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc,
664+
fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n,
665+
return_var, ASR::abiType::BindC, ASR::accessType::Public,
666+
ASR::deftypeType::Interface, s2c(al, name), false, false, false,
667+
false, false, nullptr, 0, false, false, false, s2c(al, header));
668+
ASR::symbol_t* symbol = ASR::down_cast<ASR::symbol_t>(subrout);
669+
module_scope->add_symbol(s2c(al, name), symbol);
670+
}
671+
return module_scope->get_symbol(name);
672+
}
673+
674+
ASR::symbol_t* declare_basic_neq_function(Allocator& al, const Location& loc, SymbolTable* module_scope) {
675+
std::string name = "basic_neq";
676+
symbolic_dependencies.push_back(name);
677+
if (!module_scope->get_symbol(name)) {
678+
std::string header = "symengine/cwrapper.h";
679+
SymbolTable* fn_symtab = al.make_new<SymbolTable>(module_scope);
680+
681+
Vec<ASR::expr_t*> args;
682+
args.reserve(al, 1);
683+
ASR::symbol_t* arg1 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
684+
al, loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar,
685+
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)),
686+
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false));
687+
fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1);
688+
ASR::symbol_t* arg2 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
689+
al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
690+
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
691+
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
692+
fn_symtab->add_symbol(s2c(al, "x"), arg2);
693+
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2)));
694+
ASR::symbol_t* arg3 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
695+
al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In,
696+
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
697+
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
698+
fn_symtab->add_symbol(s2c(al, "y"), arg3);
699+
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg3)));
700+
701+
Vec<ASR::stmt_t*> body;
702+
body.reserve(al, 1);
703+
704+
Vec<char*> dep;
705+
dep.reserve(al, 1);
706+
707+
ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, fn_symtab->get_symbol("_lpython_return_variable")));
708+
ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc,
709+
fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n,
710+
return_var, ASR::abiType::BindC, ASR::accessType::Public,
711+
ASR::deftypeType::Interface, s2c(al, name), false, false, false,
712+
false, false, nullptr, 0, false, false, false, s2c(al, header));
713+
ASR::symbol_t* symbol = ASR::down_cast<ASR::symbol_t>(subrout);
714+
module_scope->add_symbol(s2c(al, name), symbol);
715+
}
716+
return module_scope->get_symbol(name);
717+
}
718+
629719
ASR::expr_t* process_attributes(Allocator &al, const Location &loc, ASR::expr_t* expr,
630720
SymbolTable* module_scope) {
631721
if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*expr)) {
@@ -772,6 +862,33 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
772862
}
773863
}
774864
}
865+
} else if (ASR::is_a<ASR::SymbolicCompare_t>(*x.m_value)) {
866+
ASR::SymbolicCompare_t *s = ASR::down_cast<ASR::SymbolicCompare_t>(x.m_value);
867+
if (s->m_op == ASR::cmpopType::Eq || s->m_op == ASR::cmpopType::NotEq) {
868+
ASR::symbol_t* sym = nullptr;
869+
if (s->m_op == ASR::cmpopType::Eq) {
870+
sym = declare_basic_eq_function(al, x.base.base.loc, module_scope);
871+
} else {
872+
sym = declare_basic_neq_function(al, x.base.base.loc, module_scope);
873+
}
874+
ASR::expr_t* value1 = handle_argument(al, x.base.base.loc, s->m_left);
875+
ASR::expr_t* value2 = handle_argument(al, x.base.base.loc, s->m_right);
876+
877+
Vec<ASR::call_arg_t> call_args;
878+
call_args.reserve(al, 1);
879+
ASR::call_arg_t call_arg1, call_arg2;
880+
call_arg1.loc = x.base.base.loc;
881+
call_arg1.m_value = value1;
882+
call_args.push_back(al, call_arg1);
883+
call_arg2.loc = x.base.base.loc;
884+
call_arg2.m_value = value2;
885+
call_args.push_back(al, call_arg2);
886+
887+
ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc,
888+
sym, sym, call_args.p, call_args.n, ASRUtils::TYPE(ASR::make_Logical_t(al, x.base.base.loc, 4)), nullptr, nullptr));
889+
ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target, function_call, nullptr));
890+
pass_result.push_back(al, stmt);
891+
}
775892
}
776893
}
777894

@@ -905,6 +1022,32 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
9051022
basic_str_sym, basic_str_sym, call_args.p, call_args.n,
9061023
ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr));
9071024
print_tmp.push_back(function_call);
1025+
} else if (ASR::is_a<ASR::SymbolicCompare_t>(*val)) {
1026+
ASR::SymbolicCompare_t *s = ASR::down_cast<ASR::SymbolicCompare_t>(val);
1027+
if (s->m_op == ASR::cmpopType::Eq || s->m_op == ASR::cmpopType::NotEq) {
1028+
ASR::symbol_t* sym = nullptr;
1029+
if (s->m_op == ASR::cmpopType::Eq) {
1030+
sym = declare_basic_eq_function(al, x.base.base.loc, module_scope);
1031+
} else {
1032+
sym = declare_basic_neq_function(al, x.base.base.loc, module_scope);
1033+
}
1034+
ASR::expr_t* value1 = handle_argument(al, x.base.base.loc, s->m_left);
1035+
ASR::expr_t* value2 = handle_argument(al, x.base.base.loc, s->m_right);
1036+
1037+
Vec<ASR::call_arg_t> call_args;
1038+
call_args.reserve(al, 1);
1039+
ASR::call_arg_t call_arg1, call_arg2;
1040+
call_arg1.loc = x.base.base.loc;
1041+
call_arg1.m_value = value1;
1042+
call_args.push_back(al, call_arg1);
1043+
call_arg2.loc = x.base.base.loc;
1044+
call_arg2.m_value = value2;
1045+
call_args.push_back(al, call_arg2);
1046+
1047+
ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc,
1048+
sym, sym, call_args.p, call_args.n, ASRUtils::TYPE(ASR::make_Logical_t(al, x.base.base.loc, 4)), nullptr, nullptr));
1049+
print_tmp.push_back(function_call);
1050+
}
9081051
} else {
9091052
print_tmp.push_back(x.m_values[i]);
9101053
}

0 commit comments

Comments
 (0)