Skip to content

Commit

Permalink
updating for multisample
Browse files Browse the repository at this point in the history
  • Loading branch information
arijitsh committed Apr 3, 2024
1 parent f291644 commit 47cc00d
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 72 deletions.
22 changes: 12 additions & 10 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -522,29 +522,31 @@ if (NOT NOUNIGEN)
else()
message(FATAL_ERROR "Cannot find Arjun. Please install it! Exiting.")
endif()
find_package(approxmc CONFIG)
if (approxmc_FOUND)
find_package(approxmc CONFIG REQUIRED)

message(STATUS "Found approxmc")
message(STATUS "ApproxMC dynamic lib: ${APPROXMC_LIBRARIES}")
message(STATUS "ApproxMC static lib: ${APPROXMC_STATIC_LIBRARIES}")
message(STATUS "ApproxMC static lib deps: ${APPROXMC_STATIC_LIBRARIES_DEPS}")
message(STATUS "ApproxMC include dirs: ${APPROXMC_INCLUDE_DIRS}")
else()
message(FATAL_ERROR "Cannot find ApproxMC. Please install it! Exiting.")
endif()

find_package(cmsgen CONFIG)
if (cmsgen_FOUND)
find_package(sbva CONFIG REQUIRED)
message(STATUS "Found sbva")
message(STATUS "sbva dynamic lib: ${SBVA_LIBRARIES}")
message(STATUS "sbva static lib: ${SBVA_STATIC_LIBRARIES}")
message(STATUS "sbva static lib deps: ${SBVA_STATIC_LIBRARIES_DEPS}")
message(STATUS "sbva include dirs: ${SBVA_INCLUDE_DIRS}")
include_directories(SYSTEM ${SBVA_INCLUDE_DIRS})


find_package(cmsgen CONFIG REQUIRED)
message(STATUS "Found cmsgen")
message(STATUS "CMSGen dynamic lib: ${CMSGEN_LIBRARIES}")
message(STATUS "CMSGen static lib: ${CMSGEN_STATIC_LIBRARIES}")
message(STATUS "CMSGen static lib deps: ${CMSGEN_STATIC_LIBRARIES_DEPS}")
message(STATUS "CMSGen include dirs: ${CMSGEN_INCLUDE_DIRS}")
include_directories(SYSTEM ${CMSGEN_INCLUDE_DIRS})

else()
message(FATAL_ERROR "Cannot find CMSGen. Please install it! Exiting.")
endif()

find_package(unigen CONFIG)
if (unigen_FOUND AND HAVE_FLAG_STD_CPP11 AND (NOT NOUNIGEN))
Expand Down
8 changes: 7 additions & 1 deletion include/stp/STPManager/UserDefinedFlags.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,14 +141,20 @@ struct UserDefinedFlags
int64_t timeout_max_conflicts = -1;
int num_solver_threads = 1;
uint64_t unisamp_seed = 12345;
uint64_t num_samples = 500;
int64_t timeout_max_time = -1; // seconds

/* Counting and Sampling mode options */
bool sampling_mode = false;
bool counting_mode = false;
bool almost_uniform_sampling = false;
bool uniform_like_sampling = false;

// check the counterexample against the original input to STP
bool check_counterexample_flag = false;
//This is derived from other settings.
bool construct_counterexample_flag = false;


// Available back-end SAT solvers.
enum SATSolvers
{
Expand Down
15 changes: 9 additions & 6 deletions include/stp/Sat/UniSamp.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,19 @@ class SATSolver;
namespace stp
{
#if defined(__GNUC__) || defined(__clang__)
class __attribute__((visibility("default"))) UniSamp : public SATSolver
class __attribute__((visibility("default"))) UniSamp : public SATSolver
#else
class UniSamp : public SATSolver
class UniSamp : public SATSolver
#endif

{
ApproxMC::AppMC* a;
UniGen::UniG* s;
ArjunNS::Arjun* arjun;
uint64_t seed;
uint64_t samples_generated = 0;
uint64_t samples_needed = 0;
bool unisamp_ran = false;

public:
UniSamp(uint64_t unisamp_seed);
Expand Down Expand Up @@ -87,17 +90,17 @@ namespace stp
virtual lbool false_literal() { return ((uint8_t)-1); }
virtual lbool undef_literal() { return ((uint8_t)0); }

uint32_t getFixedCountWithAssumptions(const stp::SATSolver::vec_literals& assumps, const std::unordered_set<unsigned>& literals );

uint32_t
getFixedCountWithAssumptions(const stp::SATSolver::vec_literals& assumps,
const std::unordered_set<unsigned>& literals);

void solveAndDump();


private:
void* temp_cl;
int64_t max_confl = 0;
int64_t max_time = 0; // seconds
};
}
} // namespace stp

#endif
2 changes: 2 additions & 0 deletions lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,15 @@ if (USE_UNIGEN)
set(stp_link_libs
${stp_link_libs}
${UNIGEN_STATIC_LIBRARIES}
${SBVA_STATIC_LIBRARIES}
${CMSGEN_STATIC_LIBRARIES}
${APPROXMC_STATIC_LIBRARIES}
${APPROXMC_STATIC_LIBRARIES_DEPS}
${UNIGEN_STATIC_LIBRARIES_DEPS})
else()
set(stp_link_libs
${stp_link_libs}
${SBVA_LIBRARIES}
${ARJUN_LIBRARIES}
${LOUVAIN_COMMUNITIES_LIBRARIES}
${APPROXMC_LIBRARIES}
Expand Down
16 changes: 14 additions & 2 deletions lib/Interface/cpp_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -450,10 +450,11 @@ void Cpp_interface::printStatus()
// Does some simple caching of prior results.
void Cpp_interface::checkSat(const ASTVec& assertionsSMT2)
{
static uint64_t samples_generated = 0;
if (ignoreCheckSatRequest)
return;

bm.GetRunTimes()->stop(RunTimes::Parsing);
// bm.GetRunTimes()->stop(RunTimes::Parsing);

checkInvariant();
assert(assertionsSMT2.size() == cache.size());
Expand All @@ -478,8 +479,9 @@ void Cpp_interface::checkSat(const ASTVec& assertionsSMT2)
// unsat. If it was sat,
// we've stored the result (but not the model), so we can shortcut and return
// what we know.
// Do not use the shortcut, if we are in sampling mode.
if (!((last_run.result == SOLVER_SATISFIABLE) ||
last_run.result == SOLVER_UNSATISFIABLE))
last_run.result == SOLVER_UNSATISFIABLE) || bm.UserFlags.sampling_mode)
{
resetSolver();

Expand Down Expand Up @@ -521,6 +523,16 @@ void Cpp_interface::checkSat(const ASTVec& assertionsSMT2)
{
getModel();
}
if(bm.UserFlags.sampling_mode)
{
getModel();
samples_generated++;
if(samples_generated < bm.UserFlags.num_samples)
{
bm.UserFlags.unisamp_seed++;
checkSat(assertionsSMT2);
}
}


bm.GetRunTimes()->start(RunTimes::Parsing);
Expand Down
114 changes: 70 additions & 44 deletions lib/Sat/UniSamp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,21 @@ THE SOFTWARE.
#include "stp/Sat/UniSamp.h"
#include "approxmc/approxmc.h"
#include "unigen/unigen.h"
#include <unordered_set>
#include <algorithm>
#include <unordered_set>
using std::vector;

using namespace CMSat;
using namespace UniGen; // namespace in UniGen library


namespace stp
{

vector<vector<int>> unigen_models;

static vector<vector<int>> unigen_models;

void mycallback(const std::vector<int>& solution, void*)
{
unigen_models.push_back(solution);
unigen_models.push_back(solution);
}

void UniSamp::enableRefinement(const bool enable)
Expand All @@ -60,9 +58,13 @@ UniSamp::UniSamp(uint64_t unisamp_seed)
s = new UniG(a);
arjun = new ArjunNS::Arjun;
seed = unisamp_seed;
samples_needed = num_samples;

s->set_callback(mycallback, NULL);
a->set_verbosity(1);
a->set_verbosity(0);
arjun->set_verbosity(0);
s->set_verbosity(0);

a->set_seed(seed);
// s->log_to_file("stp.cnf");
//s->set_num_threads(num_threads);
Expand All @@ -88,8 +90,7 @@ void UniSamp::setMaxTime(int64_t _max_time)
max_time = _max_time;
}

bool UniSamp::addClause(
const vec_literals& ps) // Add a clause to the solver.
bool UniSamp::addClause(const vec_literals& ps) // Add a clause to the solver.
{
// Cryptominisat uses a slightly different vec class.
// Cryptominisat uses a slightly different Lit class too.
Expand All @@ -100,12 +101,10 @@ bool UniSamp::addClause(
{
real_temp_cl.push_back(CMSat::Lit(var(ps[i]), sign(ps[i])));
}
arjun->add_clause(real_temp_cl);
return a->add_clause(real_temp_cl);
return arjun->add_clause(real_temp_cl);
}

bool UniSamp::okay()
const // FALSE means solver is in a conflicting state
bool UniSamp::okay() const // FALSE means solver is in a conflicting state
{
//return a->okay();
return true; //TODO AS: implement well
Expand All @@ -114,50 +113,80 @@ bool UniSamp::okay()
bool UniSamp::solve(bool& timeout_expired) // Search without assumptions.
{


/*
* STP uses -1 for a value of "no timeout" -- this means that we only set the
* timeout _in the SAT solver_ if the value is >= 0. This avoids us
* accidentally setting a large limit (or one in the past).
*/


// CMSat::lbool ret = s->solve(); // TODO AS
samples_generated += 1;
if (unisamp_ran)
return true;

std::cout << "c [stp->unigen] UniSamp solving instance with " << a->nVars()
<< " variables." << std::endl;

vector <uint32_t> sampling_vars, sampling_vars_orig ;
for(uint32_t i = 0; i < a->nVars(); i++)
vector<uint32_t> sampling_vars, sampling_vars_orig;
for (uint32_t i = 0; i < a->nVars(); i++)
sampling_vars.push_back(i);

arjun->set_seed(5);

sampling_vars_orig = sampling_vars;
arjun->set_starting_sampling_set(sampling_vars_orig);
sampling_vars = arjun->get_indep_set();
std::cout << "c [unigen->arjun] sampling var size [from arjun] " << sampling_vars.size() << "\n";

arjun->set_sampl_vars(sampling_vars_orig);
bool ret = true;
const uint32_t orig_num_vars = arjun->get_orig_num_vars();
a->new_vars(orig_num_vars);
arjun->start_getting_constraints(false);
vector<Lit> clause;
while (ret)
{
bool is_xor, rhs;
ret = arjun->get_next_constraint(clause, is_xor, rhs);
assert(rhs);
assert(!is_xor);
if (!ret)
break;

bool ok = true;
for (auto l : clause)
{
if (l.var() >= orig_num_vars)
{
ok = false;
break;
}
}

if (ok)
{
a->add_clause(clause);
}
}
arjun->end_getting_constraints();
sampling_vars = arjun->run_backwards();
delete arjun;
a->set_sampl_vars(sampling_vars);

//TODO AS: this is debugging as Arjun is not performing correctly
//sampling_vars = sampling_vars_orig;

a->set_projection_set(sampling_vars);
std::cout << "c [unigen->arjun] sampling var size [from arjun] "
<< sampling_vars.size() << "\n";

auto sol_count = a->count();
s->set_full_sampling_vars(sampling_vars_orig);
std::cout << "c [stp->unigen] ApproxMC got count " << sol_count.cellSolCount
<< "*2**" << sol_count.hashCount << std::endl;
// std::cout << "c [stp->unigen] ApproxMC got count " << sol_count.cellSolCount
// << "*2**" << sol_count.hashCount << std::endl;

s->sample(&sol_count,10);
s->sample(&sol_count, samples_needed);
unisamp_ran = true;
return true;
}

uint8_t UniSamp::modelValue(uint32_t x) const
{
// if (unigen_models[0].size() < sampling_vars.size())
// std::cout << "c [stp->unigen] ERROR! found model size is not large enough\n";
return (unigen_models[0].at(x) > 0);
return (unigen_models[samples_generated - 1].at(x) > 0);
}

uint32_t UniSamp::newVar()
Expand All @@ -169,7 +198,6 @@ uint32_t UniSamp::newVar()

void UniSamp::setVerbosity(int v)
{
s->set_verbosity(0);
a->set_verbosity(0);
arjun->set_verbosity(0);
}
Expand All @@ -185,16 +213,16 @@ void UniSamp::printStats() const
}

void UniSamp::solveAndDump()
{
bool t;
solve(t);
//s->open_file_and_dump_irred_clauses("clauses.txt");
}


{
bool t;
solve(t);
//s->open_file_and_dump_irred_clauses("clauses.txt");
}

// Count how many literals/bits get fixed subject to the assumptions..
uint32_t UniSamp::getFixedCountWithAssumptions(const stp::SATSolver::vec_literals& assumps, const std::unordered_set<unsigned>& literals )
uint32_t UniSamp::getFixedCountWithAssumptions(
const stp::SATSolver::vec_literals& assumps,
const std::unordered_set<unsigned>& literals)
{
/* TODO AS skip all?
const uint64_t conf = 0; // TODO AS: s->get_sum_conflicts();
Expand All @@ -203,7 +231,7 @@ uint32_t UniSamp::getFixedCountWithAssumptions(const stp::SATSolver::vec_literal
// const CMSat::lbool r = s->simplify(); TODO AS
// Add the assumptions are clauses.
vector<CMSat::Lit>& real_temp_cl = *(vector<CMSat::Lit>*)temp_cl;
for (int i = 0; i < assumps.size(); i++)
Expand All @@ -222,12 +250,12 @@ uint32_t UniSamp::getFixedCountWithAssumptions(const stp::SATSolver::vec_literal
if (literals.find(l.var()) != literals.end())
assigned++;
}
//std::cerr << assigned << " assignments at end" <<std::endl;
// The assumptions are each single literals (corresponding to bits) that are true/false.
// The assumptions are each single literals (corresponding to bits) that are true/false.
// so in the result they should be all be set
assert(assumps.size() >= 0);
assert(assigned >= static_cast<uint32_t>(assumps.size()));
Expand All @@ -240,6 +268,4 @@ uint32_t UniSamp::getFixedCountWithAssumptions(const stp::SATSolver::vec_literal
return assigned;
}



} //end namespace stp
Loading

0 comments on commit 47cc00d

Please sign in to comment.