Skip to content

Commit

Permalink
Change solver interface to better work with internal solver states.
Browse files Browse the repository at this point in the history
That means that model enumeration is no longer part of the solver interface, but should be done using the ConfigurationFactory and ConfigurationIterator.
  • Loading branch information
boehmseb committed Nov 8, 2023
1 parent 67463ab commit 8e4b1dc
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 123 deletions.
32 changes: 29 additions & 3 deletions include/vara/Solver/ConfigurationFactory.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class ConfigurationIterable {
public:
class ConfigurationIterator {
public:
using iterator_category = std::forward_iterator_tag;
using iterator_category = std::input_iterator_tag;
using difference_type = std::ptrdiff_t;
using value_type = std::unique_ptr<vara::feature::Configuration>;
using pointer = std::unique_ptr<vara::feature::Configuration> *;
Expand Down Expand Up @@ -90,8 +90,34 @@ class ConfigurationFactory {
std::vector<std::unique_ptr<vara::feature::Configuration>>>
getAllConfigs(feature::FeatureModel &Model,
const vara::solver::SolverType Type = SolverType::Z3) {
auto S = SolverFactory::initializeSolver(Model, Type);
return S->getAllValidConfigurations();
auto V = std::vector<std::unique_ptr<vara::feature::Configuration>>();
for (auto Config : ConfigurationFactory::getConfigIterator(Model, Type)) {
if (!Config) {
return Error(Config.getError());
}
V.emplace_back(Config.extractValue());
}
return V;
}

/// This method returns the number of configurations of the given feature
/// model.
/// Note that this method needs to enumerate all configurations first.
/// If you need to access the configurations afterwards, prefer to call
/// \c getAllConfigs and check the result's size.
///
/// \param Model the given model containing the features and constraints
/// \param Type the type of solver to use
///
/// \returns the number of configurations for the given model
static Result<SolverErrorCode, uint64_t>
getNumConfigs(feature::FeatureModel &Model,
const vara::solver::SolverType Type = SolverType::Z3) {
auto Configs = getAllConfigs(Model, Type);
if (!Configs) {
return Error(Configs.getError());
}
return Configs.extractValue().size();
}

/// This method returns not all but the specified amount of configurations.
Expand Down
19 changes: 0 additions & 19 deletions include/vara/Solver/Solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,15 +117,6 @@ class Solver {
/// the current constraint system is solvable (\c true) or not (\c false).
virtual Result<SolverErrorCode, bool> hasValidConfigurations() = 0;

/// Returns the number of valid configurations of the current constraint
/// system (i.e., its features and its constraints). In principle, this is a
/// #SAT call (i.e., enumerating all configurations).
///
/// \returns an error if the number of valid configurations can not be
/// retried. This can be the case if there are still constraints left that
/// were not included into the solver because of missing variables.
virtual Result<SolverErrorCode, uint64_t> getNumberValidConfigurations() = 0;

/// Returns the current configuration.
///
/// \returns the current configuration found by the solver an error code in
Expand All @@ -140,16 +131,6 @@ class Solver {
/// unsatisfiable).
virtual Result<SolverErrorCode, std::unique_ptr<vara::feature::Configuration>>
getNextConfiguration() = 0;

/// Returns all valid configurations. In comparison to \c
/// getNumberValidConfigurations, this method returns the configurations
/// instead of a number of configurations.
///
/// \returns an error if an error occurs while retrieving the configurations.
/// Otherwise, it will return the configurations.
virtual Result<SolverErrorCode,
std::vector<std::unique_ptr<vara::feature::Configuration>>>
getAllValidConfigurations() = 0;
};

} // namespace vara::solver
Expand Down
18 changes: 4 additions & 14 deletions include/vara/Solver/Z3Solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,23 +60,16 @@ class Z3Solver : public Solver {
Result<SolverErrorCode, std::unique_ptr<vara::feature::Configuration>>
getCurrentConfiguration() override;

Result<SolverErrorCode, uint64_t> getNumberValidConfigurations() override;

Result<SolverErrorCode, std::unique_ptr<vara::feature::Configuration>>
getNextConfiguration() override;

Result<SolverErrorCode,
std::vector<std::unique_ptr<vara::feature::Configuration>>>
getAllValidConfigurations() override;

private:
// The Z3SolverConstraintVisitor is a friend class to access the solver and
// the context.
friend class Z3SolverConstraintVisitor;

/// Exclude the current configuration by adding it as a constraint
/// \return an error code in case of error.
Result<SolverErrorCode> excludeCurrentConfiguration();
/// Exclude the current configuration by adding it as a constraint.
void excludeCurrentConfiguration();

/// Processes the constraints of the binary feature and ignores the 'optional'
/// constraint if the feature is in an alternative group.
Expand All @@ -96,11 +89,8 @@ class Z3Solver : public Solver {
/// variables.
std::unique_ptr<z3::solver> Solver;

/// Flag that indicates whether the solver state has been modified by calling
/// \c getNextConfiguration.
/// This is important for functions that want to enumerate all configurations,
/// like \c getAllValidConfigurations or \c getNumberValidConfigurations.
bool Dirty = false;
/// The current model of the SAT solver.
std::optional<z3::model> CurrentModel;
};

/// \brief This class is a visitor to convert the constraints from the
Expand Down
63 changes: 19 additions & 44 deletions lib/Solver/Z3Solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,12 @@ Result<SolverErrorCode> Z3Solver::addMixedConstraint(
}

Result<SolverErrorCode, bool> Z3Solver::hasValidConfigurations() {
// If CurrentModel exists, we heave already modified the solver state, thus,
// the result of this function might be wrong.
if (CurrentModel) {
return Error(ILLEGAL_STATE);
}

if (Solver->check() == z3::sat) {
return Ok(true);
}
Expand All @@ -177,50 +183,14 @@ Result<SolverErrorCode, bool> Z3Solver::hasValidConfigurations() {

Result<SolverErrorCode, std::unique_ptr<vara::feature::Configuration>>
Z3Solver::getNextConfiguration() {
if (Dirty) {
excludeCurrentConfiguration();
} else {
Dirty = true;
}

if (Solver->check() == z3::unsat) {
return UNSAT;
}
CurrentModel = Solver->get_model();
excludeCurrentConfiguration();
return getCurrentConfiguration();
}

Result<SolverErrorCode, uint64_t> Z3Solver::getNumberValidConfigurations() {
if (Dirty) {
return Error(ILLEGAL_STATE);
}

Solver->push();
uint64_t Count = 0;
while (getNextConfiguration()) {
Count++;
}
Solver->pop();
Dirty = false;
return Count;
}

Result<SolverErrorCode,
std::vector<std::unique_ptr<vara::feature::Configuration>>>
Z3Solver::getAllValidConfigurations() {
if (Dirty) {
return Error(ILLEGAL_STATE);
}

Solver->push();
auto Vector = std::vector<std::unique_ptr<vara::feature::Configuration>>();
while (auto Config = getNextConfiguration()) {
Vector.insert(Vector.begin(), Config.extractValue());
}
Solver->pop();
Dirty = false;
return Vector;
}

Result<SolverErrorCode>
Z3Solver::setBinaryFeatureConstraints(const feature::BinaryFeature &Feature,
bool IsInAlternativeGroup) {
Expand All @@ -238,12 +208,15 @@ Z3Solver::setBinaryFeatureConstraints(const feature::BinaryFeature &Feature,
return Ok();
}

Result<SolverErrorCode> Z3Solver::excludeCurrentConfiguration() {
const z3::model M = Solver->get_model();
void Z3Solver::excludeCurrentConfiguration() {
if (!CurrentModel) {
return;
}

z3::expr Expr = Context.bool_val(false);
for (const auto &Entry : OptionToVariableMapping) {
const z3::expr OptionExpr = *Entry.getValue();
const z3::expr Value = M.eval(OptionExpr, true);
const z3::expr Value = CurrentModel->eval(OptionExpr, true);
if (Value.is_bool()) {
if (Value.is_true()) {
Expr = Expr || !OptionExpr;
Expand All @@ -255,17 +228,19 @@ Result<SolverErrorCode> Z3Solver::excludeCurrentConfiguration() {
}
}
Solver->add(Expr);
return Ok();
}

Result<SolverErrorCode, std::unique_ptr<vara::feature::Configuration>>
Z3Solver::getCurrentConfiguration() {
const z3::model M = Solver->get_model();
if (!CurrentModel) {
return getNextConfiguration();
}

auto Config = std::make_unique<vara::feature::Configuration>();

for (const auto &Entry : OptionToVariableMapping) {
const z3::expr OptionExpr = *Entry.getValue();
const z3::expr Value = M.eval(OptionExpr, true);
const z3::expr Value = CurrentModel->eval(OptionExpr, true);
Config->setConfigurationOption(Entry.getKey(),
llvm::StringRef(Value.to_string()));
}
Expand Down
13 changes: 7 additions & 6 deletions unittests/Solver/SolverFactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@

#include "vara/Feature/ConstraintBuilder.h"
#include "vara/Feature/FeatureModelBuilder.h"
#include "vara/Solver/ConfigurationFactory.h"
#include "gtest/gtest.h"

namespace vara::solver {

TEST(SolverFactory, EmptyZ3SolverTest) {
auto S = SolverFactory::initializeSolver(SolverType::Z3);
auto E = S->getNumberValidConfigurations();
EXPECT_TRUE(E);
EXPECT_EQ(E.extractValue(), 1);
auto I = ConfigurationIterable(std::move(S));
EXPECT_TRUE(*I.begin());
EXPECT_EQ(std::distance(I.begin(), I.end()), 1);
}

TEST(SolverFactory, GeneralZ3Test) {
Expand Down Expand Up @@ -51,9 +52,9 @@ TEST(SolverFactory, GeneralZ3Test) {
auto FM = B.buildFeatureModel();
auto S = SolverFactory::initializeSolver(*FM, SolverType::Z3);

auto E = S->getNumberValidConfigurations();
EXPECT_TRUE(E);
EXPECT_EQ(E.extractValue(), 6 * 63);
auto I = ConfigurationIterable(std::move(S));
EXPECT_TRUE(*I.begin());
EXPECT_EQ(std::distance(I.begin(), I.end()), 6 * 63);
}

} // namespace vara::solver
45 changes: 8 additions & 37 deletions unittests/Solver/Z3Tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "z3++.h"

#include "vara/Solver/ConfigurationFactory.h"
#include "vara/Feature/FeatureModelBuilder.h"
#include "gtest/gtest.h"

Expand Down Expand Up @@ -54,10 +55,6 @@ TEST(Z3Solver, AddFeatureObjectTest) {
V = S->hasValidConfigurations();
EXPECT_TRUE(V);
EXPECT_TRUE(V.extractValue());
// Enumerate the solutions
auto Enumerate = S->getNumberValidConfigurations();
EXPECT_TRUE(Enumerate);
EXPECT_EQ(2, Enumerate.extractValue());

E = S->addFeature(*FM->getFeature("B"));
EXPECT_TRUE(E);
Expand All @@ -68,37 +65,10 @@ TEST(Z3Solver, AddFeatureObjectTest) {
EXPECT_FALSE(E);
EXPECT_EQ(SolverErrorCode::ALREADY_PRESENT, E.getError());

// Enumerate the solutions
Enumerate = S->getNumberValidConfigurations();
EXPECT_TRUE(Enumerate);
EXPECT_EQ(6, Enumerate.extractValue());

E = S->addFeature(*FM->getFeature("C"));
EXPECT_TRUE(E);
V = S->hasValidConfigurations();
EXPECT_TRUE(V.extractValue());
Enumerate = S->getNumberValidConfigurations();
EXPECT_TRUE(Enumerate);
EXPECT_EQ(6, Enumerate.extractValue());
}

TEST(Z3Solver, TestAllValidConfigurations) {
std::unique_ptr<Z3Solver> S = Z3Solver::create();
vara::feature::FeatureModelBuilder B;
B.makeRoot("root");
B.makeFeature<vara::feature::BinaryFeature>("Foo", true);
B.addEdge("root", "Foo");
std::vector<int64_t> Values{0, 1};
B.makeFeature<vara::feature::NumericFeature>("Num1", Values);
auto FM = B.buildFeatureModel();

S->addFeature(*FM->getFeature("root"));
S->addFeature(*FM->getFeature("Foo"));
S->addFeature(*FM->getFeature("Num1"));

auto C = S->getAllValidConfigurations();
EXPECT_TRUE(C);
EXPECT_EQ(C.extractValue().size(), 4);
}

TEST(Z3Solver, TestGetNextConfiguration) {
Expand Down Expand Up @@ -158,9 +128,9 @@ TEST(Z3Solver, AddImpliesConstraint) {
S->addFeature(*FM->getFeature("a"));
S->addFeature(*FM->getFeature("b"));
S->addConstraint(*C);
auto E = S->getNumberValidConfigurations();
EXPECT_TRUE(E);
EXPECT_EQ(E.extractValue(), 6);
auto I = ConfigurationIterable(std::move(S));
EXPECT_TRUE(*I.begin());
EXPECT_EQ(std::distance(I.begin(), I.end()), 6);
}

TEST(Z3Solver, AddAlternative) {
Expand Down Expand Up @@ -203,9 +173,10 @@ TEST(Z3Solver, AddAlternative) {
for (const auto &R : FM->relationships()) {
S->addRelationship(*R);
}
auto E = S->getNumberValidConfigurations();
EXPECT_TRUE(E);
EXPECT_EQ(E.extractValue(), 63);

auto I = ConfigurationIterable(std::move(S));
EXPECT_TRUE(*I.begin());
EXPECT_EQ(std::distance(I.begin(), I.end()), 63);
}

} // namespace vara::solver

0 comments on commit 8e4b1dc

Please sign in to comment.