Skip to content

Commit

Permalink
Merge pull request #3 from msoos/bug
Browse files Browse the repository at this point in the history
Fixing sampling issue with UniGen
  • Loading branch information
arijitsh authored Oct 25, 2024
2 parents f4c2227 + ea71501 commit 28ab263
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 48 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -853,3 +853,4 @@ install(FILES
install(EXPORT ${STP_EXPORT_NAME} DESTINATION
"${STP_INSTALL_CMAKE_DIR}"
)

4 changes: 2 additions & 2 deletions include/stp/STPManager/UserDefinedFlags.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ struct UserDefinedFlags
int64_t AIG_rewrites_iterations = 0; // Number of iterations of AIG rewrites.
int64_t bitblast_simplification = 0;
int64_t size_reducing_fixed_point = 0;


bool simplify_to_constants_only = false;

Expand Down Expand Up @@ -141,7 +141,7 @@ struct UserDefinedFlags
int64_t timeout_max_conflicts = -1;
int num_solver_threads = 1;
uint64_t unisamp_seed = 12345;
uint64_t num_samples = 500;
uint64_t num_samples = 10;
uint64_t samples_generated = 0;
int64_t timeout_max_time = -1; // seconds

Expand Down
5 changes: 3 additions & 2 deletions include/stp/Sat/UniSamp.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@ class UniSamp : public SATSolver
#endif

{
ApproxMC::AppMC* a;
UniGen::UniG* s;
vector<vector<int>> unigen_models;
ApproxMC::AppMC* appmc;
UniGen::UniG* unigen;
ArjunNS::Arjun* arjun;
uint64_t seed;
uint64_t samples_generated = 0;
Expand Down
2 changes: 2 additions & 0 deletions lib/Sat/ApxMC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ bool ApxMC::solve(bool& timeout_expired) // Search without assumptions.
sampling_vars_orig = sampling_vars;
arjun->set_sampl_vars(sampling_vars_orig);
sampling_vars = arjun->run_backwards();
auto empty_sampl_vars = arjun->get_empty_sampl_vars();
const auto ret = arjun->get_fully_simplified_renumbered_cnf(sc);
sampling_vars = ret.sampl_vars;
a->new_vars(ret.nvars);
Expand All @@ -154,6 +155,7 @@ bool ApxMC::solve(bool& timeout_expired) // Search without assumptions.
a->set_sampl_vars(sampling_vars);

auto sol_count = a->count();
sol_count.hashCount += empty_sampl_vars.size();

// use gmp to get the absolute count of solutions
mpz_class result;
Expand Down
96 changes: 52 additions & 44 deletions lib/Sat/UniSamp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,19 @@ using std::vector;
using namespace CMSat;
using namespace UniGen; // namespace in UniGen library

using std::cout;
using std::endl;

namespace stp
{

static vector<vector<int>> unigen_models;

void mycallback(const std::vector<int>& solution, void*)
void mycallback(const std::vector<int>& solution, void* data)
{
unigen_models.push_back(solution);
vector<vector<int>>* unigen_models = (vector<vector<int>>*)data;
/* for (auto s : solution) std::cout << (s>0 ? "1" : "0"); */
/* std::cout << std::endl; */
unigen_models->push_back(solution);
}

void UniSamp::enableRefinement(const bool enable)
Expand All @@ -55,19 +60,18 @@ UniSamp::UniSamp(uint64_t unisamp_seed, uint64_t _samples_needed,
uint64_t _samples_generated)
{

a = new ApproxMC::AppMC;
s = new UniG(a);
appmc = new ApproxMC::AppMC;
unigen = new UniG(appmc);
arjun = new ArjunNS::Arjun;
seed = unisamp_seed;
samples_needed = _samples_needed;
samples_generated = _samples_generated;
// unisamp_ran = false;
s->set_callback(mycallback, NULL);
a->set_verbosity(0);
unigen->set_callback(mycallback, &unigen_models);
appmc->set_verbosity(0);
arjun->set_verbosity(0);
s->set_verbosity(0);

a->set_seed(seed);
unigen->set_verbosity(0);
appmc->set_seed(seed);

// s->log_to_file("stp.cnf");
//s->set_num_threads(num_threads);
Expand All @@ -78,7 +82,7 @@ UniSamp::UniSamp(uint64_t unisamp_seed, uint64_t _samples_needed,

UniSamp::~UniSamp()
{
delete s;
delete unigen;
vector<CMSat::Lit>* real_temp_cl = (vector<CMSat::Lit>*)temp_cl;
delete real_temp_cl;
}
Expand All @@ -95,15 +99,13 @@ void UniSamp::setMaxTime(int64_t _max_time)

bool UniSamp::addClause(const vec_literals& ps) // Add a clause to the solver.
{
// Cryptominisat uses a slightly different vec class.
// Cryptominisat uses a slightly different Lit class too.

vector<CMSat::Lit>& real_temp_cl = *(vector<CMSat::Lit>*)temp_cl;
real_temp_cl.clear();
for (int i = 0; i < ps.size(); i++)
{
real_temp_cl.push_back(CMSat::Lit(var(ps[i]), sign(ps[i])));
}
/* cout << "c Adding clause to arjun " << real_temp_cl << " 0" << endl; */
return arjun->add_clause(real_temp_cl);
}

Expand All @@ -129,62 +131,64 @@ bool UniSamp::solve(bool& timeout_expired) // Search without assumptions.
if (samples_generated > 1)
return true;

std::cout << "c [stp->unigen] UniSamp solving instance with " << a->nVars()
std::cout << "c [stp->unigen] UniSamp solving instance with " << arjun->nVars()
<< " variables." << std::endl;

vector<uint32_t> sampling_vars, sampling_vars_orig;
for (uint32_t i = 0; i < a->nVars(); i++)
sampling_vars.push_back(i);

arjun->set_seed(5);

for (uint32_t i = 0; i < arjun->nVars(); i++) sampling_vars.push_back(i);
sampling_vars_orig = sampling_vars;
arjun->set_sampl_vars(sampling_vars_orig);

const uint32_t orig_num_vars = arjun->nVars();
appmc->new_vars(orig_num_vars);

bool ret = true;
const uint32_t orig_num_vars = arjun->get_orig_num_vars();
a->new_vars(orig_num_vars);
arjun->start_getting_constraints(false);
vector<Lit> clause;
while (ret)
{
while (ret) {
bool is_xor, rhs;
ret = arjun->get_next_constraint(clause, is_xor, rhs);
assert(rhs);
assert(!is_xor);
if (!ret)
break;
if (!ret) break;

bool ok = true;
for (auto l : clause)
{
if (l.var() >= orig_num_vars)
{
for (auto l : clause) {
if (l.var() >= orig_num_vars) {
ok = false;
break;
}
}

if (ok)
{
a->add_clause(clause);
if (ok) {
/* cout << "adding clause to appmc " << clause << endl; */
appmc->add_clause(clause);
}
}
arjun->end_getting_constraints();
sampling_vars = arjun->run_backwards();
auto empty_sampl_vars = arjun->get_empty_sampl_vars();
delete arjun;
a->set_sampl_vars(sampling_vars);

appmc->set_sampl_vars(sampling_vars);

std::cout << "c [unigen->arjun] sampling var size [from arjun] "
<< sampling_vars.size() << " orig size "
<< sampling_vars_orig.size() << "\n";

auto sol_count = a->count();
s->set_full_sampling_vars(sampling_vars_orig);
auto sol_count = appmc->count();
cout << "c Sol count: " << sol_count.cellSolCount
<< "*2**" << (sol_count.hashCount+empty_sampl_vars.size()) << endl;

// std::cout << "c [stp->unigen] ApproxMC got count " << sol_count.cellSolCount
// << "*2**" << sol_count.hashCount << std::endl;

s->sample(&sol_count, samples_needed);
unigen->set_verbosity(0);
unigen->set_verb_sampler_cls(0);
unigen->set_kappa(0.1);
unigen->set_multisample(false);
unigen->set_full_sampling_vars(sampling_vars_orig);
unigen->set_empty_sampling_vars(empty_sampl_vars);

unigen->sample(&sol_count, samples_needed);
unisamp_ran = true;
return true;
}
Expand All @@ -193,25 +197,29 @@ uint8_t UniSamp::modelValue(uint32_t x) const
{
// if (unigen_models[0].size() < sampling_vars.size())
// std::cout << "c [stp->unigen] ERROR! found model size is not large enough\n";
return (unigen_models[samples_generated].at(x) > 0);
if (samples_generated >= unigen_models.size())
{
std::cout << "c [stp->unigen] ERROR! samples_generated: " << samples_generated
<< " but unigen_models.size(): " << unigen_models.size() << std::endl;
exit(-1);
}
return (unigen_models.at(samples_generated).at(x) > 0);
}

uint32_t UniSamp::newVar()
{
a->new_var();
arjun->new_var();
return a->nVars() - 1;
return arjun->nVars() - 1;
}

void UniSamp::setVerbosity(int v)
{
a->set_verbosity(0);
arjun->set_verbosity(0);
}

unsigned long UniSamp::nVars() const
{
return a->nVars();
return arjun->nVars();
}

void UniSamp::printStats() const
Expand Down

0 comments on commit 28ab263

Please sign in to comment.