diff --git a/include/ae/AeValSolver.hpp b/include/ae/AeValSolver.hpp index 219199d5b..eb96424eb 100644 --- a/include/ae/AeValSolver.hpp +++ b/include/ae/AeValSolver.hpp @@ -49,7 +49,7 @@ namespace ufo efac(s->getFactory()), z3(efac), smt (z3), - u(efac), + u(efac,z3), fresh_var_ind(0), partitioning_size(0), debug(0) @@ -794,7 +794,8 @@ namespace ufo template static Expr eliminateQuantifiersRepl(Expr fla, Range& vars) { ExprFactory &efac = fla->getFactory(); - SMTUtils u(efac); + EZ3 ez3(efac); + SMTUtils u(efac, ez3); ExprSet complex; findComplexNumerics(fla, complex); ExprMap repls; @@ -831,7 +832,8 @@ namespace ufo inline static Expr abduce (Expr goal, Expr assm) { ExprFactory &efac = goal->getFactory(); - SMTUtils u(efac); + EZ3 ez3(efac); + SMTUtils u(efac, ez3); ExprSet complex; findComplexNumerics(assm, complex); findComplexNumerics(goal, complex); diff --git a/include/ae/SMTUtils.hpp b/include/ae/SMTUtils.hpp index 4e4a14ecd..07bed1978 100644 --- a/include/ae/SMTUtils.hpp +++ b/include/ae/SMTUtils.hpp @@ -15,7 +15,7 @@ namespace ufo private: ExprFactory &efac; - EZ3 z3; + EZ3& z3; ZSolver smt; bool can_get_model; ZSolver::Model* m; @@ -24,14 +24,15 @@ namespace ufo public: - SMTUtils (ExprFactory& _efac) : - efac(_efac), z3(efac), smt (z3), can_get_model(0), m(NULL) {} - SMTUtils (ExprFactory& _efac, unsigned _to) : - efac(_efac), z3(efac), smt (z3, _to), can_get_model(0), m(NULL) {} + SMTUtils (ExprFactory& _efac, EZ3& _z3) : + efac(_efac), z3(_z3), smt (z3), can_get_model(0), m(NULL) {} - SMTUtils (ExprFactory& _efac, ExprVector& _accessors, unsigned _to, bool _bv) : - efac(_efac), z3(efac), smt (z3, _to), can_get_model(0), m(NULL) + SMTUtils (ExprFactory& _efac, EZ3& _z3, unsigned _to) : + efac(_efac), z3(_z3), smt (z3, _to), can_get_model(0), m(NULL) {} + + SMTUtils (ExprFactory& _efac, EZ3& _z3, ExprVector& _accessors, unsigned _to, bool _bv) : + efac(_efac), z3(_z3), smt (z3, _to), can_get_model(0), m(NULL) { approxBV = _bv; for(auto b : _accessors) @@ -42,12 +43,9 @@ namespace ufo ; } - SMTUtils (ExprFactory& _efac, ExprVector& _accessors, unsigned _to, std::vector adts, std::vector adts_seen) : - efac(_efac), z3(efac), smt (z3, _to), can_get_model(0), m(NULL) + SMTUtils (ExprFactory& _efac, EZ3& _z3, ExprVector& _accessors, unsigned _to) : + efac(_efac), z3(_z3), smt (z3, _to), can_get_model(0), m(NULL) { - - z3.adts = adts; - z3.adts_seen = adts_seen; for(auto b : _accessors) if (b->arity() == 3) accessors.insert(b); diff --git a/include/deep/NonlinCHCsolver.hpp b/include/deep/NonlinCHCsolver.hpp index fbe56988d..2c98429d7 100644 --- a/include/deep/NonlinCHCsolver.hpp +++ b/include/deep/NonlinCHCsolver.hpp @@ -71,7 +71,7 @@ namespace ufo NonlinCHCsolver(CHCs &r, map>> & s) : m_efac(r.m_efac), ruleManager(r), - u(m_efac, r.m_z3.getAdtAccessors(), 10000, r.m_z3.adts, r.m_z3.adts_seen), signature(s) {} + u(m_efac, r.m_z3, r.m_z3.getAdtAccessors(), 10000), signature(s) {} bool checkAllOver(bool checkQuery = false) { for (auto &hr: ruleManager.chcs) { diff --git a/include/ufo/Smt/Z3n.hpp b/include/ufo/Smt/Z3n.hpp index f5ecfd222..ef3d16837 100644 --- a/include/ufo/Smt/Z3n.hpp +++ b/include/ufo/Smt/Z3n.hpp @@ -283,9 +283,6 @@ namespace ufo Z3_set_ast_print_mode (ctx, Z3_PRINT_SMTLIB2_COMPLIANT); } - public: - std::vector adts; - std::vector adts_seen; protected: z3::context &get_ctx () { return ctx; } @@ -294,14 +291,14 @@ namespace ufo expr_ast_map seen_expr; z3::ast toAst (Expr e) { - return M::marshal (e, get_ctx (), cache.left, seen_expr, adts, adts_seen); + return M::marshal (e, get_ctx (), cache.left, seen_expr); } Expr toExpr (z3::ast a) { if (!a) return Expr(); // ast_expr_map seen; - auto res = U::unmarshal (a, get_efac (), cache.right, seen_ast, adts_seen, adts, accessors); + auto res = U::unmarshal (a, get_efac (), cache.right, seen_ast, accessors); return res; } @@ -352,7 +349,6 @@ namespace ufo return out.str (); } - ExprVector& getAdtConstructors(){ return adts; } ExprVector& getAdtAccessors(){ return accessors; } template diff --git a/include/ufo/Smt/ZExprConverter.hpp b/include/ufo/Smt/ZExprConverter.hpp index eee0daf73..a6852b9ff 100644 --- a/include/ufo/Smt/ZExprConverter.hpp +++ b/include/ufo/Smt/ZExprConverter.hpp @@ -42,7 +42,7 @@ namespace ufo { template static z3::ast marshal (Expr e, z3::context &ctx, - C &cache, expr_ast_map &seen, std::vector adts, std::vector &adts_seen) + C &cache, expr_ast_map &seen) { assert (e); if (isOpX(e)) return z3::ast (ctx, Z3_mk_true (ctx)); @@ -65,35 +65,52 @@ namespace ufo if (bind::isBVar (e)) { - z3::ast sort (marshal (bind::type (e), ctx, cache, seen, adts, adts_seen)); + z3::ast sort (marshal (bind::type (e), ctx, cache, seen)); res = Z3_mk_bound (ctx, bind::bvarId (e), reinterpret_cast (static_cast (sort))); } - else if (isOpX (e)) - res = reinterpret_cast (Z3_mk_int_sort (ctx)); + else if (isOpX (e)) { + res = reinterpret_cast (Z3_mk_int_sort(ctx)); + printf("Post reinterpret: %s\n", Z3_ast_to_string(ctx, res)); + } else if (isOpX (e)) res = reinterpret_cast (Z3_mk_real_sort(ctx)); else if (isOpX (e)) res = reinterpret_cast (Z3_mk_bool_sort (ctx)); else if (isOpX (e)) { - res = reinterpret_cast (Z3_mk_int_sort (ctx)); -// res = reinterpret_cast (Z3_mk_datatype_sort(ctx, Z3_mk_string_symbol(ctx, lexical_cast(e->left ()).c_str()))); + // res = reinterpret_cast (Z3_mk_int_sort (ctx)); + std::string name = lexical_cast(e->left()); + Z3_symbol z3_name = Z3_mk_string_symbol(ctx, name.c_str()); + // Z3_constructor csts [constructors[name].size()]; + // for(int i = 0; i (typeDt); + printf("Post reinterpret: %s\n", Z3_ast_to_string(ctx, res)); }// GF: hack for now else if (isOpX (e)) { - z3::ast _idx_sort (marshal (e->left (), ctx, cache, seen, adts, adts_seen)); - z3::ast _val_sort (marshal (e->right (), ctx, cache, seen, adts, adts_seen)); - Z3_sort idx_sort = reinterpret_cast + z3::ast _idx_sort (marshal (e->left (), ctx, cache, seen)); + z3::ast _val_sort (marshal (e->right (), ctx, cache, seen)); + Z3_sort idx_sort = reinterpret_cast (static_cast (_idx_sort)); - Z3_sort val_sort = reinterpret_cast + Z3_sort val_sort = reinterpret_cast (static_cast (_val_sort)); - res = reinterpret_cast - (Z3_mk_array_sort (ctx, idx_sort, val_sort)); + res = reinterpret_cast + (Z3_mk_array_sort (ctx, idx_sort, val_sort)); } else if (isOpX (e)) res = reinterpret_cast (Z3_mk_bv_sort (ctx, bv::width (e))); - + else if (isOpX(e)) { z3::sort sort (ctx, @@ -121,7 +138,7 @@ namespace ufo { z3::sort sort (ctx, Z3_mk_bv_sort (ctx, bv::width (e->arg (1)))); const MPZ& num = dynamic_cast (e->arg (0)->op ()); - + std::string val = boost::lexical_cast (num.get ()); res = Z3_mk_numeral (ctx, val.c_str (), sort); } @@ -178,7 +195,7 @@ namespace ufo for (size_t i = 0; i < bind::domainSz (e); ++i) { - z3::ast a (marshal (bind::domainTy (e, i), ctx, cache, seen, adts, adts_seen)); + z3::ast a (marshal (bind::domainTy (e, i), ctx, cache, seen)); pinned.push_back (a); domain [i] = reinterpret_cast (static_cast(a)); } @@ -187,7 +204,7 @@ namespace ufo z3::sort range (ctx, reinterpret_cast (static_cast - (marshal (bind::rangeTy (e), ctx, cache, seen, adts, adts_seen)))); + (marshal (bind::rangeTy (e), ctx, cache, seen)))); Expr fname = bind::fname (e); @@ -213,7 +230,7 @@ namespace ufo z3::func_decl zfdecl (ctx, reinterpret_cast (static_cast - (marshal (bind::fname (e), ctx, cache, seen, adts, adts_seen)))); + (marshal (bind::fname (e), ctx, cache, seen)))); // -- marshall all arguments except for the first one // -- (which is the fdecl) std::vector args (e->arity ()); @@ -224,7 +241,7 @@ namespace ufo for (ENode::args_iterator it = ++ (e->args_begin ()), end = e->args_end (); it != end; ++it) { - z3::ast a (marshal (*it, ctx, cache, seen, adts, adts_seen)); + z3::ast a (marshal (*it, ctx, cache, seen)); pinned_args.push_back (a); args [pos++] = a; } @@ -239,11 +256,11 @@ namespace ufo for (int i = 0; i < e->arity() - 1; i++) vars.push_back(bind::fapp(e->arg(i))); - z3::ast ast (marshal (e->last(), ctx, cache, seen, adts, adts_seen)); //z3.toAst (e->last())); + z3::ast ast (marshal (e->last(), ctx, cache, seen)); //z3.toAst (e->last())); std::vector bound; bound.reserve (boost::size (vars)); for (const Expr &v : vars) - bound.push_back (Z3_to_app (ctx, marshal (v, ctx, cache, seen, adts, adts_seen))); + bound.push_back (Z3_to_app (ctx, marshal (v, ctx, cache, seen))); if (isOpX (e)) res = Z3_mk_forall_const (ctx, 0, @@ -272,49 +289,49 @@ namespace ufo // -- then it's a NEG or UN_MINUS if (isOpX(e)) { - z3::ast arg = marshal (e->left(), ctx, cache, seen, adts, adts_seen); + z3::ast arg = marshal (e->left(), ctx, cache, seen); return z3::ast (ctx, Z3_mk_unary_minus(ctx, arg)); } if (isOpX(e)) { - z3::ast arg = marshal (e->left(), ctx, cache, seen, adts, adts_seen); + z3::ast arg = marshal (e->left(), ctx, cache, seen); return z3::ast (ctx, Z3_mk_not(ctx, arg)); } if (isOpX (e)) { - z3::ast arg = marshal (e->left(), ctx, cache, seen, adts, adts_seen); + z3::ast arg = marshal (e->left(), ctx, cache, seen); return z3::ast (ctx, Z3_mk_array_default (ctx, arg)); } if (isOpX(e)) { - z3::ast arg = marshal (e->left(), ctx, cache, seen, adts, adts_seen); + z3::ast arg = marshal (e->left(), ctx, cache, seen); return z3::ast (ctx, Z3_mk_bvnot(ctx, arg)); } if (isOpX(e)) { - z3::ast arg = marshal (e->left(), ctx, cache, seen, adts, adts_seen); + z3::ast arg = marshal (e->left(), ctx, cache, seen); return z3::ast (ctx, Z3_mk_bvneg(ctx, arg)); } if (isOpX(e)) { - z3::ast arg = marshal (e->left(), ctx, cache, seen, adts, adts_seen); + z3::ast arg = marshal (e->left(), ctx, cache, seen); return z3::ast (ctx, Z3_mk_bvredand(ctx, arg)); } if (isOpX(e)) { - z3::ast arg = marshal (e->left(), ctx, cache, seen, adts, adts_seen); + z3::ast arg = marshal (e->left(), ctx, cache, seen); return z3::ast (ctx, Z3_mk_bvredor(ctx, arg)); } if (isOpX(e)) { - z3::ast arg = marshal (e->left(), ctx, cache, seen, adts, adts_seen); + z3::ast arg = marshal (e->left(), ctx, cache, seen); //TODO: SECOND NUMBER IS THE AMOUNT OF BITS IN THE BV ENCODING return z3::ast (ctx, Z3_mk_int2bv(ctx, 64, arg)); } if (isOpX(e)) { - z3::ast arg = marshal (e->left(), ctx, cache, seen, adts, adts_seen); + z3::ast arg = marshal (e->left(), ctx, cache, seen); //TODO: BOOL DESCRIBES IF NUMBER IS UNSIGNED OR NOT return z3::ast (ctx, Z3_mk_bv2int(ctx, arg, true)); } @@ -323,8 +340,8 @@ namespace ufo } else if (arity == 2) { - z3::ast t1 = marshal(e->left(), ctx, cache, seen, adts, adts_seen); - z3::ast t2 = marshal(e->right(), ctx, cache, seen, adts, adts_seen); + z3::ast t1 = marshal(e->left(), ctx, cache, seen); + z3::ast t2 = marshal(e->right(), ctx, cache, seen); Z3_ast args [2] = {t1, t2}; @@ -373,13 +390,13 @@ namespace ufo else if (isOpX