diff --git a/CMakeLists.txt b/CMakeLists.txt index 8e9f338b..86a1e97c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -522,19 +522,24 @@ if (NOT NOUNIGEN) else() message(FATAL_ERROR "Cannot find Arjun. Please install it! Exiting.") endif() - find_package(approxmc CONFIG) - if (approxmc_FOUND) + find_package(approxmc CONFIG REQUIRED) + message(STATUS "Found approxmc") message(STATUS "ApproxMC dynamic lib: ${APPROXMC_LIBRARIES}") message(STATUS "ApproxMC static lib: ${APPROXMC_STATIC_LIBRARIES}") message(STATUS "ApproxMC static lib deps: ${APPROXMC_STATIC_LIBRARIES_DEPS}") message(STATUS "ApproxMC include dirs: ${APPROXMC_INCLUDE_DIRS}") - else() - message(FATAL_ERROR "Cannot find ApproxMC. Please install it! Exiting.") - endif() - find_package(cmsgen CONFIG) - if (cmsgen_FOUND) + find_package(sbva CONFIG REQUIRED) + message(STATUS "Found sbva") + message(STATUS "sbva dynamic lib: ${SBVA_LIBRARIES}") + message(STATUS "sbva static lib: ${SBVA_STATIC_LIBRARIES}") + message(STATUS "sbva static lib deps: ${SBVA_STATIC_LIBRARIES_DEPS}") + message(STATUS "sbva include dirs: ${SBVA_INCLUDE_DIRS}") + include_directories(SYSTEM ${SBVA_INCLUDE_DIRS}) + + + find_package(cmsgen CONFIG REQUIRED) message(STATUS "Found cmsgen") message(STATUS "CMSGen dynamic lib: ${CMSGEN_LIBRARIES}") message(STATUS "CMSGen static lib: ${CMSGEN_STATIC_LIBRARIES}") @@ -542,9 +547,6 @@ if (NOT NOUNIGEN) message(STATUS "CMSGen include dirs: ${CMSGEN_INCLUDE_DIRS}") include_directories(SYSTEM ${CMSGEN_INCLUDE_DIRS}) - else() - message(FATAL_ERROR "Cannot find CMSGen. Please install it! Exiting.") - endif() find_package(unigen CONFIG) if (unigen_FOUND AND HAVE_FLAG_STD_CPP11 AND (NOT NOUNIGEN)) diff --git a/include/stp/STPManager/UserDefinedFlags.h b/include/stp/STPManager/UserDefinedFlags.h index e64d1a0d..96ed46ee 100644 --- a/include/stp/STPManager/UserDefinedFlags.h +++ b/include/stp/STPManager/UserDefinedFlags.h @@ -141,14 +141,20 @@ struct UserDefinedFlags int64_t timeout_max_conflicts = -1; int num_solver_threads = 1; uint64_t unisamp_seed = 12345; + uint64_t num_samples = 500; int64_t timeout_max_time = -1; // seconds + /* Counting and Sampling mode options */ + bool sampling_mode = false; + bool counting_mode = false; + bool almost_uniform_sampling = false; + bool uniform_like_sampling = false; + // check the counterexample against the original input to STP bool check_counterexample_flag = false; //This is derived from other settings. bool construct_counterexample_flag = false; - // Available back-end SAT solvers. enum SATSolvers { diff --git a/include/stp/Sat/UniSamp.h b/include/stp/Sat/UniSamp.h index 51a5bfbd..45b58bf0 100644 --- a/include/stp/Sat/UniSamp.h +++ b/include/stp/Sat/UniSamp.h @@ -44,9 +44,9 @@ class SATSolver; namespace stp { #if defined(__GNUC__) || defined(__clang__) - class __attribute__((visibility("default"))) UniSamp : public SATSolver +class __attribute__((visibility("default"))) UniSamp : public SATSolver #else - class UniSamp : public SATSolver +class UniSamp : public SATSolver #endif { @@ -54,6 +54,9 @@ namespace stp UniGen::UniG* s; ArjunNS::Arjun* arjun; uint64_t seed; + uint64_t samples_generated = 0; + uint64_t samples_needed = 0; + bool unisamp_ran = false; public: UniSamp(uint64_t unisamp_seed); @@ -87,17 +90,17 @@ namespace stp virtual lbool false_literal() { return ((uint8_t)-1); } virtual lbool undef_literal() { return ((uint8_t)0); } - uint32_t getFixedCountWithAssumptions(const stp::SATSolver::vec_literals& assumps, const std::unordered_set& literals ); - + uint32_t + getFixedCountWithAssumptions(const stp::SATSolver::vec_literals& assumps, + const std::unordered_set& literals); void solveAndDump(); - private: void* temp_cl; int64_t max_confl = 0; int64_t max_time = 0; // seconds }; -} +} // namespace stp #endif diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index b425e684..416d0879 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -142,6 +142,7 @@ if (USE_UNIGEN) set(stp_link_libs ${stp_link_libs} ${UNIGEN_STATIC_LIBRARIES} + ${SBVA_STATIC_LIBRARIES} ${CMSGEN_STATIC_LIBRARIES} ${APPROXMC_STATIC_LIBRARIES} ${APPROXMC_STATIC_LIBRARIES_DEPS} @@ -149,6 +150,7 @@ if (USE_UNIGEN) else() set(stp_link_libs ${stp_link_libs} + ${SBVA_LIBRARIES} ${ARJUN_LIBRARIES} ${LOUVAIN_COMMUNITIES_LIBRARIES} ${APPROXMC_LIBRARIES} diff --git a/lib/Interface/cpp_interface.cpp b/lib/Interface/cpp_interface.cpp index d6b7e1cd..b5b2fba9 100644 --- a/lib/Interface/cpp_interface.cpp +++ b/lib/Interface/cpp_interface.cpp @@ -450,10 +450,11 @@ void Cpp_interface::printStatus() // Does some simple caching of prior results. void Cpp_interface::checkSat(const ASTVec& assertionsSMT2) { + static uint64_t samples_generated = 0; if (ignoreCheckSatRequest) return; - bm.GetRunTimes()->stop(RunTimes::Parsing); + // bm.GetRunTimes()->stop(RunTimes::Parsing); checkInvariant(); assert(assertionsSMT2.size() == cache.size()); @@ -478,8 +479,9 @@ void Cpp_interface::checkSat(const ASTVec& assertionsSMT2) // unsat. If it was sat, // we've stored the result (but not the model), so we can shortcut and return // what we know. + // Do not use the shortcut, if we are in sampling mode. if (!((last_run.result == SOLVER_SATISFIABLE) || - last_run.result == SOLVER_UNSATISFIABLE)) + last_run.result == SOLVER_UNSATISFIABLE) || bm.UserFlags.sampling_mode) { resetSolver(); @@ -521,6 +523,16 @@ void Cpp_interface::checkSat(const ASTVec& assertionsSMT2) { getModel(); } + if(bm.UserFlags.sampling_mode) + { + getModel(); + samples_generated++; + if(samples_generated < bm.UserFlags.num_samples) + { + bm.UserFlags.unisamp_seed++; + checkSat(assertionsSMT2); + } + } bm.GetRunTimes()->start(RunTimes::Parsing); diff --git a/lib/Sat/UniSamp.cpp b/lib/Sat/UniSamp.cpp index b3d1656d..b4c53a1d 100644 --- a/lib/Sat/UniSamp.cpp +++ b/lib/Sat/UniSamp.cpp @@ -25,23 +25,21 @@ THE SOFTWARE. #include "stp/Sat/UniSamp.h" #include "approxmc/approxmc.h" #include "unigen/unigen.h" -#include #include +#include using std::vector; using namespace CMSat; using namespace UniGen; // namespace in UniGen library - namespace stp { -vector> unigen_models; - +static vector> unigen_models; void mycallback(const std::vector& solution, void*) { - unigen_models.push_back(solution); + unigen_models.push_back(solution); } void UniSamp::enableRefinement(const bool enable) @@ -60,9 +58,13 @@ UniSamp::UniSamp(uint64_t unisamp_seed) s = new UniG(a); arjun = new ArjunNS::Arjun; seed = unisamp_seed; + samples_needed = num_samples; s->set_callback(mycallback, NULL); - a->set_verbosity(1); + a->set_verbosity(0); + arjun->set_verbosity(0); + s->set_verbosity(0); + a->set_seed(seed); // s->log_to_file("stp.cnf"); //s->set_num_threads(num_threads); @@ -88,8 +90,7 @@ void UniSamp::setMaxTime(int64_t _max_time) max_time = _max_time; } -bool UniSamp::addClause( - const vec_literals& ps) // Add a clause to the solver. +bool UniSamp::addClause(const vec_literals& ps) // Add a clause to the solver. { // Cryptominisat uses a slightly different vec class. // Cryptominisat uses a slightly different Lit class too. @@ -100,12 +101,10 @@ bool UniSamp::addClause( { real_temp_cl.push_back(CMSat::Lit(var(ps[i]), sign(ps[i]))); } - arjun->add_clause(real_temp_cl); - return a->add_clause(real_temp_cl); + return arjun->add_clause(real_temp_cl); } -bool UniSamp::okay() - const // FALSE means solver is in a conflicting state +bool UniSamp::okay() const // FALSE means solver is in a conflicting state { //return a->okay(); return true; //TODO AS: implement well @@ -114,42 +113,72 @@ bool UniSamp::okay() bool UniSamp::solve(bool& timeout_expired) // Search without assumptions. { - /* * STP uses -1 for a value of "no timeout" -- this means that we only set the * timeout _in the SAT solver_ if the value is >= 0. This avoids us * accidentally setting a large limit (or one in the past). */ - // CMSat::lbool ret = s->solve(); // TODO AS + samples_generated += 1; + if (unisamp_ran) + return true; + std::cout << "c [stp->unigen] UniSamp solving instance with " << a->nVars() << " variables." << std::endl; - vector sampling_vars, sampling_vars_orig ; - for(uint32_t i = 0; i < a->nVars(); i++) + vector sampling_vars, sampling_vars_orig; + for (uint32_t i = 0; i < a->nVars(); i++) sampling_vars.push_back(i); arjun->set_seed(5); sampling_vars_orig = sampling_vars; - arjun->set_starting_sampling_set(sampling_vars_orig); - sampling_vars = arjun->get_indep_set(); - std::cout << "c [unigen->arjun] sampling var size [from arjun] " << sampling_vars.size() << "\n"; - + arjun->set_sampl_vars(sampling_vars_orig); + bool ret = true; + const uint32_t orig_num_vars = arjun->get_orig_num_vars(); + a->new_vars(orig_num_vars); + arjun->start_getting_constraints(false); + vector clause; + while (ret) + { + bool is_xor, rhs; + ret = arjun->get_next_constraint(clause, is_xor, rhs); + assert(rhs); + assert(!is_xor); + if (!ret) + break; + + bool ok = true; + for (auto l : clause) + { + if (l.var() >= orig_num_vars) + { + ok = false; + break; + } + } + + if (ok) + { + a->add_clause(clause); + } + } + arjun->end_getting_constraints(); + sampling_vars = arjun->run_backwards(); delete arjun; + a->set_sampl_vars(sampling_vars); - //TODO AS: this is debugging as Arjun is not performing correctly - //sampling_vars = sampling_vars_orig; - - a->set_projection_set(sampling_vars); + std::cout << "c [unigen->arjun] sampling var size [from arjun] " + << sampling_vars.size() << "\n"; auto sol_count = a->count(); s->set_full_sampling_vars(sampling_vars_orig); - std::cout << "c [stp->unigen] ApproxMC got count " << sol_count.cellSolCount - << "*2**" << sol_count.hashCount << std::endl; + // std::cout << "c [stp->unigen] ApproxMC got count " << sol_count.cellSolCount + // << "*2**" << sol_count.hashCount << std::endl; - s->sample(&sol_count,10); + s->sample(&sol_count, samples_needed); + unisamp_ran = true; return true; } @@ -157,7 +186,7 @@ uint8_t UniSamp::modelValue(uint32_t x) const { // if (unigen_models[0].size() < sampling_vars.size()) // std::cout << "c [stp->unigen] ERROR! found model size is not large enough\n"; - return (unigen_models[0].at(x) > 0); + return (unigen_models[samples_generated - 1].at(x) > 0); } uint32_t UniSamp::newVar() @@ -169,7 +198,6 @@ uint32_t UniSamp::newVar() void UniSamp::setVerbosity(int v) { - s->set_verbosity(0); a->set_verbosity(0); arjun->set_verbosity(0); } @@ -185,16 +213,16 @@ void UniSamp::printStats() const } void UniSamp::solveAndDump() - { - bool t; - solve(t); - //s->open_file_and_dump_irred_clauses("clauses.txt"); - } - - +{ + bool t; + solve(t); + //s->open_file_and_dump_irred_clauses("clauses.txt"); +} // Count how many literals/bits get fixed subject to the assumptions.. -uint32_t UniSamp::getFixedCountWithAssumptions(const stp::SATSolver::vec_literals& assumps, const std::unordered_set& literals ) +uint32_t UniSamp::getFixedCountWithAssumptions( + const stp::SATSolver::vec_literals& assumps, + const std::unordered_set& literals) { /* TODO AS skip all? const uint64_t conf = 0; // TODO AS: s->get_sum_conflicts(); @@ -203,7 +231,7 @@ uint32_t UniSamp::getFixedCountWithAssumptions(const stp::SATSolver::vec_literal // const CMSat::lbool r = s->simplify(); TODO AS - + // Add the assumptions are clauses. vector& real_temp_cl = *(vector*)temp_cl; for (int i = 0; i < assumps.size(); i++) @@ -222,12 +250,12 @@ uint32_t UniSamp::getFixedCountWithAssumptions(const stp::SATSolver::vec_literal if (literals.find(l.var()) != literals.end()) assigned++; } - - - + + + //std::cerr << assigned << " assignments at end" <= 0); assert(assigned >= static_cast(assumps.size())); @@ -240,6 +268,4 @@ uint32_t UniSamp::getFixedCountWithAssumptions(const stp::SATSolver::vec_literal return assigned; } - - } //end namespace stp diff --git a/tools/stp/main.cpp b/tools/stp/main.cpp index 5950f464..f3bda208 100644 --- a/tools/stp/main.cpp +++ b/tools/stp/main.cpp @@ -28,8 +28,8 @@ THE SOFTWARE. namespace po = boost::program_options; using namespace stp; -using std::cout; using std::cerr; +using std::cout; using std::endl; /******************************************************************** @@ -249,14 +249,21 @@ void ExtraMain::create_options() "(default)" #endif #endif - ) - ("unisamp,u", "use unisamp as solver -- behave as a almost uniform sampler") - ("cmsgen,s", "use cmsgen as solver -- behave as a uniform like sampler") - ("approxmc,c", "use approxmc as solver -- behave as a approximate counter") - ("seed", - po::value(&bm->UserFlags.unisamp_seed) - ->default_value(bm->UserFlags.unisamp_seed), - "Seed for counting and sampling"); + )("unisamp,u", "use unisamp as solver -- behave as a almost " + "uniform sampler")( + "cmsgen,s", + "use cmsgen as solver -- behave as a uniform like sampler")( + "approxmc,c", + "use approxmc as solver -- behave as a approximate counter")( + "seed", + po::value(&bm->UserFlags.unisamp_seed) + ->default_value(bm->UserFlags.unisamp_seed), + "Seed for counting and sampling")( + "num-samples,ns", + po::value(&bm->UserFlags.num_samples) + ->default_value(bm->UserFlags.num_samples), + "Number of samples to generate in case of sampling"); + ; po::options_description refinement_options("Refinement options"); refinement_options.add_options()( @@ -483,14 +490,19 @@ int ExtraMain::parse_options(int argc, char** argv) if (vm.count("unisamp")) { bm->UserFlags.solver_to_use = UserDefinedFlags::UNIGEN_SOLVER; + bm->UserFlags.sampling_mode = true; + bm->UserFlags.almost_uniform_sampling = true; } if (vm.count("cmsgen")) { bm->UserFlags.solver_to_use = UserDefinedFlags::CMSGEN_SOLVER; + bm->UserFlags.sampling_mode = true; + bm->UserFlags.uniform_like_sampling = true; } if (vm.count("approxmc")) { bm->UserFlags.solver_to_use = UserDefinedFlags::APPROXMC_SOLVER; + bm->UserFlags.counting_mode = true; } #endif