@@ -626,6 +626,96 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
626
626
return module_scope->get_symbol (name);
627
627
}
628
628
629
+ ASR::symbol_t * declare_basic_eq_function (Allocator& al, const Location& loc, SymbolTable* module_scope) {
630
+ std::string name = " basic_eq" ;
631
+ symbolic_dependencies.push_back (name);
632
+ if (!module_scope->get_symbol (name)) {
633
+ std::string header = " symengine/cwrapper.h" ;
634
+ SymbolTable* fn_symtab = al.make_new <SymbolTable>(module_scope);
635
+
636
+ Vec<ASR::expr_t *> args;
637
+ args.reserve (al, 1 );
638
+ ASR::symbol_t * arg1 = ASR::down_cast<ASR::symbol_t >(ASR::make_Variable_t (
639
+ al, loc, fn_symtab, s2c (al, " _lpython_return_variable" ), nullptr , 0 , ASR::intentType::ReturnVar,
640
+ nullptr , nullptr , ASR::storage_typeType::Default, ASRUtils::TYPE (ASR::make_Logical_t (al, loc, 4 )),
641
+ nullptr , ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false ));
642
+ fn_symtab->add_symbol (s2c (al, " _lpython_return_variable" ), arg1);
643
+ ASR::symbol_t * arg2 = ASR::down_cast<ASR::symbol_t >(ASR::make_Variable_t (
644
+ al, loc, fn_symtab, s2c (al, " x" ), nullptr , 0 , ASR::intentType::In,
645
+ nullptr , nullptr , ASR::storage_typeType::Default, ASRUtils::TYPE (ASR::make_CPtr_t (al, loc)),
646
+ nullptr , ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true ));
647
+ fn_symtab->add_symbol (s2c (al, " x" ), arg2);
648
+ args.push_back (al, ASRUtils::EXPR (ASR::make_Var_t (al, loc, arg2)));
649
+ ASR::symbol_t * arg3 = ASR::down_cast<ASR::symbol_t >(ASR::make_Variable_t (
650
+ al, loc, fn_symtab, s2c (al, " y" ), nullptr , 0 , ASR::intentType::In,
651
+ nullptr , nullptr , ASR::storage_typeType::Default, ASRUtils::TYPE (ASR::make_CPtr_t (al, loc)),
652
+ nullptr , ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true ));
653
+ fn_symtab->add_symbol (s2c (al, " y" ), arg3);
654
+ args.push_back (al, ASRUtils::EXPR (ASR::make_Var_t (al, loc, arg3)));
655
+
656
+ Vec<ASR::stmt_t *> body;
657
+ body.reserve (al, 1 );
658
+
659
+ Vec<char *> dep;
660
+ dep.reserve (al, 1 );
661
+
662
+ ASR::expr_t * return_var = ASRUtils::EXPR (ASR::make_Var_t (al, loc, fn_symtab->get_symbol (" _lpython_return_variable" )));
663
+ ASR::asr_t * subrout = ASRUtils::make_Function_t_util (al, loc,
664
+ fn_symtab, s2c (al, name), dep.p , dep.n , args.p , args.n , body.p , body.n ,
665
+ return_var, ASR::abiType::BindC, ASR::accessType::Public,
666
+ ASR::deftypeType::Interface, s2c (al, name), false , false , false ,
667
+ false , false , nullptr , 0 , false , false , false , s2c (al, header));
668
+ ASR::symbol_t * symbol = ASR::down_cast<ASR::symbol_t >(subrout);
669
+ module_scope->add_symbol (s2c (al, name), symbol);
670
+ }
671
+ return module_scope->get_symbol (name);
672
+ }
673
+
674
+ ASR::symbol_t * declare_basic_neq_function (Allocator& al, const Location& loc, SymbolTable* module_scope) {
675
+ std::string name = " basic_neq" ;
676
+ symbolic_dependencies.push_back (name);
677
+ if (!module_scope->get_symbol (name)) {
678
+ std::string header = " symengine/cwrapper.h" ;
679
+ SymbolTable* fn_symtab = al.make_new <SymbolTable>(module_scope);
680
+
681
+ Vec<ASR::expr_t *> args;
682
+ args.reserve (al, 1 );
683
+ ASR::symbol_t * arg1 = ASR::down_cast<ASR::symbol_t >(ASR::make_Variable_t (
684
+ al, loc, fn_symtab, s2c (al, " _lpython_return_variable" ), nullptr , 0 , ASR::intentType::ReturnVar,
685
+ nullptr , nullptr , ASR::storage_typeType::Default, ASRUtils::TYPE (ASR::make_Logical_t (al, loc, 4 )),
686
+ nullptr , ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false ));
687
+ fn_symtab->add_symbol (s2c (al, " _lpython_return_variable" ), arg1);
688
+ ASR::symbol_t * arg2 = ASR::down_cast<ASR::symbol_t >(ASR::make_Variable_t (
689
+ al, loc, fn_symtab, s2c (al, " x" ), nullptr , 0 , ASR::intentType::In,
690
+ nullptr , nullptr , ASR::storage_typeType::Default, ASRUtils::TYPE (ASR::make_CPtr_t (al, loc)),
691
+ nullptr , ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true ));
692
+ fn_symtab->add_symbol (s2c (al, " x" ), arg2);
693
+ args.push_back (al, ASRUtils::EXPR (ASR::make_Var_t (al, loc, arg2)));
694
+ ASR::symbol_t * arg3 = ASR::down_cast<ASR::symbol_t >(ASR::make_Variable_t (
695
+ al, loc, fn_symtab, s2c (al, " y" ), nullptr , 0 , ASR::intentType::In,
696
+ nullptr , nullptr , ASR::storage_typeType::Default, ASRUtils::TYPE (ASR::make_CPtr_t (al, loc)),
697
+ nullptr , ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true ));
698
+ fn_symtab->add_symbol (s2c (al, " y" ), arg3);
699
+ args.push_back (al, ASRUtils::EXPR (ASR::make_Var_t (al, loc, arg3)));
700
+
701
+ Vec<ASR::stmt_t *> body;
702
+ body.reserve (al, 1 );
703
+
704
+ Vec<char *> dep;
705
+ dep.reserve (al, 1 );
706
+
707
+ ASR::expr_t * return_var = ASRUtils::EXPR (ASR::make_Var_t (al, loc, fn_symtab->get_symbol (" _lpython_return_variable" )));
708
+ ASR::asr_t * subrout = ASRUtils::make_Function_t_util (al, loc,
709
+ fn_symtab, s2c (al, name), dep.p , dep.n , args.p , args.n , body.p , body.n ,
710
+ return_var, ASR::abiType::BindC, ASR::accessType::Public,
711
+ ASR::deftypeType::Interface, s2c (al, name), false , false , false ,
712
+ false , false , nullptr , 0 , false , false , false , s2c (al, header));
713
+ ASR::symbol_t * symbol = ASR::down_cast<ASR::symbol_t >(subrout);
714
+ module_scope->add_symbol (s2c (al, name), symbol);
715
+ }
716
+ return module_scope->get_symbol (name);
717
+ }
718
+
629
719
ASR::expr_t * process_attributes (Allocator &al, const Location &loc, ASR::expr_t * expr,
630
720
SymbolTable* module_scope) {
631
721
if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*expr)) {
@@ -772,6 +862,33 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
772
862
}
773
863
}
774
864
}
865
+ } else if (ASR::is_a<ASR::SymbolicCompare_t>(*x.m_value )) {
866
+ ASR::SymbolicCompare_t *s = ASR::down_cast<ASR::SymbolicCompare_t>(x.m_value );
867
+ if (s->m_op == ASR::cmpopType::Eq || s->m_op == ASR::cmpopType::NotEq) {
868
+ ASR::symbol_t * sym = nullptr ;
869
+ if (s->m_op == ASR::cmpopType::Eq) {
870
+ sym = declare_basic_eq_function (al, x.base .base .loc , module_scope);
871
+ } else {
872
+ sym = declare_basic_neq_function (al, x.base .base .loc , module_scope);
873
+ }
874
+ ASR::expr_t * value1 = handle_argument (al, x.base .base .loc , s->m_left );
875
+ ASR::expr_t * value2 = handle_argument (al, x.base .base .loc , s->m_right );
876
+
877
+ Vec<ASR::call_arg_t > call_args;
878
+ call_args.reserve (al, 1 );
879
+ ASR::call_arg_t call_arg1, call_arg2;
880
+ call_arg1.loc = x.base .base .loc ;
881
+ call_arg1.m_value = value1;
882
+ call_args.push_back (al, call_arg1);
883
+ call_arg2.loc = x.base .base .loc ;
884
+ call_arg2.m_value = value2;
885
+ call_args.push_back (al, call_arg2);
886
+
887
+ ASR::expr_t * function_call = ASRUtils::EXPR (ASRUtils::make_FunctionCall_t_util (al, x.base .base .loc ,
888
+ sym, sym, call_args.p , call_args.n , ASRUtils::TYPE (ASR::make_Logical_t (al, x.base .base .loc , 4 )), nullptr , nullptr ));
889
+ ASR::stmt_t * stmt = ASRUtils::STMT (ASR::make_Assignment_t (al, x.base .base .loc , x.m_target , function_call, nullptr ));
890
+ pass_result.push_back (al, stmt);
891
+ }
775
892
}
776
893
}
777
894
@@ -905,6 +1022,32 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
905
1022
basic_str_sym, basic_str_sym, call_args.p , call_args.n ,
906
1023
ASRUtils::TYPE (ASR::make_Character_t (al, x.base .base .loc , 1 , -2 , nullptr )), nullptr , nullptr ));
907
1024
print_tmp.push_back (function_call);
1025
+ } else if (ASR::is_a<ASR::SymbolicCompare_t>(*val)) {
1026
+ ASR::SymbolicCompare_t *s = ASR::down_cast<ASR::SymbolicCompare_t>(val);
1027
+ if (s->m_op == ASR::cmpopType::Eq || s->m_op == ASR::cmpopType::NotEq) {
1028
+ ASR::symbol_t * sym = nullptr ;
1029
+ if (s->m_op == ASR::cmpopType::Eq) {
1030
+ sym = declare_basic_eq_function (al, x.base .base .loc , module_scope);
1031
+ } else {
1032
+ sym = declare_basic_neq_function (al, x.base .base .loc , module_scope);
1033
+ }
1034
+ ASR::expr_t * value1 = handle_argument (al, x.base .base .loc , s->m_left );
1035
+ ASR::expr_t * value2 = handle_argument (al, x.base .base .loc , s->m_right );
1036
+
1037
+ Vec<ASR::call_arg_t > call_args;
1038
+ call_args.reserve (al, 1 );
1039
+ ASR::call_arg_t call_arg1, call_arg2;
1040
+ call_arg1.loc = x.base .base .loc ;
1041
+ call_arg1.m_value = value1;
1042
+ call_args.push_back (al, call_arg1);
1043
+ call_arg2.loc = x.base .base .loc ;
1044
+ call_arg2.m_value = value2;
1045
+ call_args.push_back (al, call_arg2);
1046
+
1047
+ ASR::expr_t * function_call = ASRUtils::EXPR (ASRUtils::make_FunctionCall_t_util (al, x.base .base .loc ,
1048
+ sym, sym, call_args.p , call_args.n , ASRUtils::TYPE (ASR::make_Logical_t (al, x.base .base .loc , 4 )), nullptr , nullptr ));
1049
+ print_tmp.push_back (function_call);
1050
+ }
908
1051
} else {
909
1052
print_tmp.push_back (x.m_values [i]);
910
1053
}
0 commit comments